metaflow 2.11.15__py2.py3-none-any.whl → 2.11.16__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. metaflow/__init__.py +3 -0
  2. metaflow/clone_util.py +6 -0
  3. metaflow/extension_support/plugins.py +2 -0
  4. metaflow/metaflow_config.py +24 -0
  5. metaflow/metaflow_environment.py +2 -2
  6. metaflow/plugins/__init__.py +19 -0
  7. metaflow/plugins/airflow/airflow.py +7 -0
  8. metaflow/plugins/argo/argo_workflows.py +17 -0
  9. metaflow/plugins/azure/__init__.py +3 -0
  10. metaflow/plugins/azure/azure_credential.py +53 -0
  11. metaflow/plugins/azure/azure_exceptions.py +1 -1
  12. metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
  13. metaflow/plugins/azure/azure_utils.py +2 -35
  14. metaflow/plugins/azure/blob_service_client_factory.py +4 -2
  15. metaflow/plugins/datastores/azure_storage.py +6 -6
  16. metaflow/plugins/datatools/s3/s3.py +1 -1
  17. metaflow/plugins/gcp/__init__.py +1 -0
  18. metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +169 -0
  19. metaflow/plugins/gcp/gs_storage_client_factory.py +52 -1
  20. metaflow/plugins/kubernetes/kubernetes.py +85 -8
  21. metaflow/plugins/kubernetes/kubernetes_cli.py +24 -1
  22. metaflow/plugins/kubernetes/kubernetes_client.py +4 -1
  23. metaflow/plugins/kubernetes/kubernetes_decorator.py +49 -4
  24. metaflow/plugins/kubernetes/kubernetes_job.py +208 -201
  25. metaflow/plugins/kubernetes/kubernetes_jobsets.py +784 -0
  26. metaflow/plugins/timeout_decorator.py +2 -1
  27. metaflow/task.py +1 -12
  28. metaflow/tuple_util.py +27 -0
  29. metaflow/util.py +0 -15
  30. metaflow/version.py +1 -1
  31. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/METADATA +2 -2
  32. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/RECORD +36 -31
  33. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/LICENSE +0 -0
  34. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/WHEEL +0 -0
  35. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/entry_points.txt +0 -0
  36. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
+ import base64
+ import json
+ from json import JSONDecodeError
+
+
+ from metaflow.exception import MetaflowException
+ from metaflow.plugins.secrets import SecretsProvider
+ import re
+ from metaflow.plugins.gcp.gs_storage_client_factory import get_credentials
+ from metaflow.metaflow_config import GCP_SECRET_MANAGER_PREFIX
+
+
+ class MetaflowGcpSecretsManagerBadResponse(MetaflowException):
+     """Raised when the response from GCP Secrets Manager is not valid in some way"""
+
+
+ class MetaflowGcpSecretsManagerDuplicateKey(MetaflowException):
+     """Raised when the response from GCP Secrets Manager contains duplicate keys"""
+
+
+ class MetaflowGcpSecretsManagerJSONParseError(MetaflowException):
+     """Raised when the SecretString response from GCP Secrets Manager is not valid JSON"""
+
+
+ class MetaflowGcpSecretsManagerNotJSONObject(MetaflowException):
+     """Raised when the SecretString response from GCP Secrets Manager is not valid JSON dictionary"""
+
+
+ def _sanitize_key_as_env_var(key):
+     """
+     Sanitize a key as an environment variable name.
+     This is purely a convenience trade-off to cover common cases well, vs. introducing
+     ambiguities (e.g. did the final '_' come from '.', or '-' or is original?).
+
+     1/27/2023(jackie):
+
+     We start with few rules and should *sparingly* add more over time.
+     Also, it's TBD whether all possible providers will share the same sanitization logic.
+     Therefore we will keep this function private for now
+     """
+     return key.replace("-", "_").replace(".", "_").replace("/", "_")
+
+
+ class GcpSecretManagerSecretsProvider(SecretsProvider):
+     TYPE = "gcp-secret-manager"
+
+     def get_secret_as_dict(self, secret_id, options={}, role=None):
+         """
+         Reads a secret from GCP Secrets Manager and returns it as a dictionary of environment variables.
+
+         If the secret contains a string payload ("SecretString"):
+         - if the `json` option is True:
+             Secret will be parsed as a JSON. If successfully parsed, AND the JSON contains a
+             top-level object, each entry K/V in the object will also be converted to an entry in the result. V will
+             always be casted to a string (if not already a string).
+         - If `json` option is False (default):
+             Will be returned as a single entry in the result, with the key being the last part after / in secret_id.
+
+         On GCP Secrets Manager, the secret payload is a binary blob. However, by default we interpret it as UTF8 encoded
+         string. To disable this, set the `binary` option to True, the binary will be base64 encoded in the result.
+
+         All keys in the result are sanitized to be more valid environment variable names. This is done on a best effort
+         basis. Further validation is expected to be done by the invoking @secrets decorator itself.
+
+         :param secret_id: GCP Secrets Manager secret ID
+         :param options: unused
+         :return: dict of environment variables. All keys and values are strings.
+         """
+         from google.cloud.secretmanager_v1.services.secret_manager_service import (
+             SecretManagerServiceClient,
+         )
+         from google.cloud.secretmanager_v1.services.secret_manager_service.transports import (
+             SecretManagerServiceTransport,
+         )
+
+         # Full secret id looks like projects/1234567890/secrets/mysecret/versions/latest
+         #
+         # We allow these forms of secret_id:
+         #
+         # 1. Full path like projects/1234567890/secrets/mysecret/versions/latest
+         #    This is what you'd specify if you used to GCP SDK.
+         #
+         # 2. Full path but without the version like projects/1234567890/secrets/mysecret.
+         #    This is what you see in the GCP console, makes it easier to copy & paste.
+         #
+         # 3. Simple string like mysecret
+         #
+         # 4. Simple string with /versions/<version> suffix like mysecret/versions/1
+
+         # The latter two forms require METAFLOW_GCP_SECRET_MANAGER_PREFIX to be set.
+
+         match_full = re.match(
+             r"^projects/\d+/secrets/([\w\-]+)(/versions/([\w\-]+))?$", secret_id
+         )
+         match_partial = re.match(r"^([\w\-]+)(/versions/[\w\-]+)?$", secret_id)
+         if match_full:
+             # Full path
+             env_var_name = match_full.group(1)
+             if match_full.group(3):
+                 # With version specified
+                 full_secret_name = secret_id
+             else:
+                 # No version specified, use latest
+                 full_secret_name = secret_id + "/versions/latest"
+         elif match_partial:
+             # Partial path, possibly with /versions/<version> suffix
+             env_var_name = secret_id
+             if not GCP_SECRET_MANAGER_PREFIX:
+                 raise ValueError(
+                     "Cannot use simple secret_id without setting METAFLOW_GCP_SECRET_MANAGER_PREFIX. %s"
+                     % GCP_SECRET_MANAGER_PREFIX
+                 )
+             if match_partial.group(2):
+                 # With version specified
+                 full_secret_name = "%s%s" % (GCP_SECRET_MANAGER_PREFIX, secret_id)
+                 env_var_name = match_partial.group(1)
+             else:
+                 # No version specified, use latest
+                 full_secret_name = "%s%s/versions/latest" % (
+                     GCP_SECRET_MANAGER_PREFIX,
+                     secret_id,
+                 )
+         else:
+             raise ValueError(
+                 "Invalid secret_id: %s. Must be either a full path or a simple string."
+                 % secret_id
+             )
+
+         result = {}
+
+         def _sanitize_and_add_entry_to_result(k, v):
+             # Two jobs - sanitize, and check for dupes
+             sanitized_k = _sanitize_key_as_env_var(k)
+             if sanitized_k in result:
+                 raise MetaflowGcpSecretsManagerDuplicateKey(
+                     "Duplicate key in secret: '%s' (sanitizes to '%s')"
+                     % (k, sanitized_k)
+                 )
+             result[sanitized_k] = v
+
+         credentials, _ = get_credentials(
+             scopes=SecretManagerServiceTransport.AUTH_SCOPES
+         )
+         client = SecretManagerServiceClient(credentials=credentials)
+         response = client.access_secret_version(request={"name": full_secret_name})
+         payload_str = response.payload.data.decode("UTF-8")
+         if options.get("json", False):
+             obj = json.loads(payload_str)
+             if type(obj) == dict:
+                 for k, v in obj.items():
+                     # We try to make it work here - cast to string always
+                     _sanitize_and_add_entry_to_result(k, str(v))
+             else:
+                 raise MetaflowGcpSecretsManagerNotJSONObject(
+                     "Secret string is a JSON, but not an object (dict-like) - actual type %s."
+                     % type(obj)
+                 )
+         else:
+             if options.get("env_var_name"):
+                 env_var_name = options["env_var_name"]
+
+             if options.get("binary", False):
+                 _sanitize_and_add_entry_to_result(
+                     env_var_name, base64.b64encode(response.payload.data)
+                 )
+             else:
+                 _sanitize_and_add_entry_to_result(env_var_name, payload_str)
+
+         return result
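For orientation (this is not part of the diff): a minimal sketch of how the new gcp-secret-manager provider could be exercised from a flow. The secret name mysecret, the env var DB_PASSWORD, and the dict-style source spec are illustrative assumptions and should be checked against the Metaflow docs for your version; simple secret names also require METAFLOW_GCP_SECRET_MANAGER_PREFIX to be configured, per the docstring above.

import os

from metaflow import FlowSpec, secrets, step


class GcpSecretDemoFlow(FlowSpec):
    # Hypothetical secret id; with json=True each top-level key of the JSON payload
    # is exposed as a sanitized environment variable in the task.
    @secrets(
        sources=[
            {"type": "gcp-secret-manager", "id": "mysecret", "options": {"json": True}}
        ]
    )
    @step
    def start(self):
        # DB_PASSWORD is a hypothetical key inside the secret's JSON payload.
        print("db password present:", "DB_PASSWORD" in os.environ)
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    GcpSecretDemoFlow()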
@@ -8,7 +8,7 @@ def _get_cache_key():
      return os.getpid(), threading.get_ident()


- def get_gs_storage_client():
+ def _get_gs_storage_client_default():
      cache_key = _get_cache_key()
      if cache_key not in _client_cache:
          from google.cloud import storage
@@ -19,3 +19,54 @@ def get_gs_storage_client():
              credentials=credentials, project=project_id
          )
      return _client_cache[cache_key]
+
+
+ class GcpDefaultClientProvider(object):
+     name = "gcp-default"
+
+     @staticmethod
+     def get_gs_storage_client(*args, **kwargs):
+         return _get_gs_storage_client_default()
+
+     @staticmethod
+     def get_credentials(scopes, *args, **kwargs):
+         import google.auth
+
+         return google.auth.default(scopes=scopes)
+
+
+ cached_provider_class = None
+
+
+ def get_gs_storage_client():
+     global cached_provider_class
+     if cached_provider_class is None:
+         from metaflow.metaflow_config import DEFAULT_GCP_CLIENT_PROVIDER
+         from metaflow.plugins import GCP_CLIENT_PROVIDERS
+
+         for p in GCP_CLIENT_PROVIDERS:
+             if p.name == DEFAULT_GCP_CLIENT_PROVIDER:
+                 cached_provider_class = p
+                 break
+         else:
+             raise ValueError(
+                 "Cannot find GCP Client provider %s" % DEFAULT_GCP_CLIENT_PROVIDER
+             )
+     return cached_provider_class.get_gs_storage_client()
+
+
+ def get_credentials(scopes, *args, **kwargs):
+     global cached_provider_class
+     if cached_provider_class is None:
+         from metaflow.metaflow_config import DEFAULT_GCP_CLIENT_PROVIDER
+         from metaflow.plugins import GCP_CLIENT_PROVIDERS
+
+         for p in GCP_CLIENT_PROVIDERS:
+             if p.name == DEFAULT_GCP_CLIENT_PROVIDER:
+                 cached_provider_class = p
+                 break
+         else:
+             raise ValueError(
+                 "Cannot find GCP Client provider %s" % DEFAULT_GCP_CLIENT_PROVIDER
+             )
+     return cached_provider_class.get_credentials(scopes, *args, **kwargs)
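The refactor above makes the GS storage client pluggable: get_gs_storage_client() and get_credentials() now resolve a provider by name from GCP_CLIENT_PROVIDERS using DEFAULT_GCP_CLIENT_PROVIDER, with gcp-default preserving the old behavior. As a rough sketch (the class and provider name below are hypothetical, and actual registration goes through Metaflow's plugin/extension machinery, which is not shown here), an alternative provider only needs to mirror the two static methods of GcpDefaultClientProvider:

import google.auth
from google.cloud import storage


class ImpersonatedGcpClientProvider(object):
    # Hypothetical provider name; METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER would select it
    # once the class is registered under GCP_CLIENT_PROVIDERS.
    name = "gcp-impersonated"

    @staticmethod
    def get_credentials(scopes, *args, **kwargs):
        # Swap in impersonated or workload-identity credentials here if needed;
        # google.auth.default returns a (credentials, project_id) tuple.
        return google.auth.default(scopes=scopes)

    @staticmethod
    def get_gs_storage_client(*args, **kwargs):
        credentials, project_id = ImpersonatedGcpClientProvider.get_credentials(
            scopes=["https://www.googleapis.com/auth/devstorage.read_write"]
        )
        return storage.Client(credentials=credentials, project=project_id)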
@@ -3,9 +3,9 @@ import math
  import os
  import re
  import shlex
+ import copy
  import time
  from typing import Dict, List, Optional
- import uuid
  from uuid import uuid4

  from metaflow import current, util
@@ -27,8 +27,11 @@ from metaflow.metaflow_config import (
      DATASTORE_SYSROOT_S3,
      DATATOOLS_S3ROOT,
      DEFAULT_AWS_CLIENT_PROVIDER,
+     DEFAULT_GCP_CLIENT_PROVIDER,
      DEFAULT_METADATA,
      DEFAULT_SECRETS_BACKEND_TYPE,
+     GCP_SECRET_MANAGER_PREFIX,
+     AZURE_KEY_VAULT_PREFIX,
      KUBERNETES_FETCH_EC2_METADATA,
      KUBERNETES_LABELS,
      KUBERNETES_SANDBOX_INIT_SCRIPT,
@@ -66,6 +69,12 @@ class KubernetesKilledException(MetaflowException):
      headline = "Kubernetes Batch job killed"


+ def _extract_labels_and_annotations_from_job_spec(job_spec):
+     annotations = job_spec.template.metadata.annotations
+     labels = job_spec.template.metadata.labels
+     return copy.copy(annotations), copy.copy(labels)
+
+
  class Kubernetes(object):
      def __init__(
          self,
@@ -140,9 +149,64 @@ class Kubernetes(object):
          return shlex.split('bash -c "%s"' % cmd_str)

      def launch_job(self, **kwargs):
-         self._job = self.create_job(**kwargs).execute()
+         if (
+             "num_parallel" in kwargs
+             and kwargs["num_parallel"]
+             and int(kwargs["num_parallel"]) > 0
+         ):
+             job = self.create_job_object(**kwargs)
+             spec = job.create_job_spec()
+             # `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
+             # This will be modified by the KubernetesJobSet object
+             annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
+             self._job = self.create_jobset(
+                 job_spec=spec,
+                 run_id=kwargs["run_id"],
+                 step_name=kwargs["step_name"],
+                 task_id=kwargs["task_id"],
+                 namespace=kwargs["namespace"],
+                 env=kwargs["env"],
+                 num_parallel=kwargs["num_parallel"],
+                 port=kwargs["port"],
+                 annotations=annotations,
+                 labels=labels,
+             ).execute()
+         else:
+             kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
+             self._job = self.create_job_object(**kwargs).create().execute()
+
+     def create_jobset(
+         self,
+         job_spec=None,
+         run_id=None,
+         step_name=None,
+         task_id=None,
+         namespace=None,
+         env=None,
+         num_parallel=None,
+         port=None,
+         annotations=None,
+         labels=None,
+     ):
+         if env is None:
+             env = {}
+
+         _prefix = str(uuid4())[:6]
+         js = KubernetesClient().jobset(
+             name="js-%s" % _prefix,
+             run_id=run_id,
+             task_id=task_id,
+             step_name=step_name,
+             namespace=namespace,
+             labels=self._get_labels(labels),
+             annotations=annotations,
+             num_parallel=num_parallel,
+             job_spec=job_spec,
+             port=port,
+         )
+         return js

-     def create_job(
+     def create_job_object(
          self,
          flow_name,
          run_id,
@@ -176,14 +240,15 @@ class Kubernetes(object):
          labels=None,
          shared_memory=None,
          port=None,
+         name_pattern=None,
+         num_parallel=None,
      ):
          if env is None:
              env = {}
-
          job = (
              KubernetesClient()
              .job(
-                 generate_name="t-{uid}-".format(uid=str(uuid4())[:8]),
+                 generate_name=name_pattern,
                  namespace=namespace,
                  service_account=service_account,
                  secrets=secrets,
@@ -217,6 +282,7 @@ class Kubernetes(object):
                  persistent_volume_claims=persistent_volume_claims,
                  shared_memory=shared_memory,
                  port=port,
+                 num_parallel=num_parallel,
              )
              .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
              .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -243,10 +309,19 @@ class Kubernetes(object):
              .environment_variable(
                  "METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER", DEFAULT_AWS_CLIENT_PROVIDER
              )
+             .environment_variable(
+                 "METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER", DEFAULT_GCP_CLIENT_PROVIDER
+             )
              .environment_variable(
                  "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
                  AWS_SECRETS_MANAGER_DEFAULT_REGION,
              )
+             .environment_variable(
+                 "METAFLOW_GCP_SECRET_MANAGER_PREFIX", GCP_SECRET_MANAGER_PREFIX
+             )
+             .environment_variable(
+                 "METAFLOW_AZURE_KEY_VAULT_PREFIX", AZURE_KEY_VAULT_PREFIX
+             )
              .environment_variable("METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL)
              .environment_variable(
                  "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT",
@@ -332,6 +407,9 @@ class Kubernetes(object):
              .label("app.kubernetes.io/part-of", "metaflow")
          )

+         return job
+
+     def create_k8sjob(self, job):
          return job.create()

      def wait(self, stdout_location, stderr_location, echo=None):
@@ -366,7 +444,7 @@ class Kubernetes(object):
                      t = time.time()
                  time.sleep(update_delay(time.time() - start_time))

-         prefix = b"[%s] " % util.to_bytes(self._job.id)
+         _make_prefix = lambda: b"[%s] " % util.to_bytes(self._job.id)

          stdout_tail = get_log_tailer(stdout_location, self._datastore.TYPE)
          stderr_tail = get_log_tailer(stderr_location, self._datastore.TYPE)
@@ -376,7 +454,7 @@ class Kubernetes(object):

          # 2) Tail logs until the job has finished
          tail_logs(
-             prefix=prefix,
+             prefix=_make_prefix(),
              stdout_tail=stdout_tail,
              stderr_tail=stderr_tail,
              echo=echo,
@@ -392,7 +470,6 @@ class Kubernetes(object):
          # exists prior to calling S3Tail and note the user about
          # truncated logs if it doesn't.
          # TODO : For hard crashes, we can fetch logs from the pod.
-
          if self._job.has_failed:
              exit_code, reason = self._job.reason
              msg = next(
@@ -7,11 +7,17 @@ from metaflow import JSONTypeClass, util
  from metaflow._vendor import click
  from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
  from metaflow.metadata.util import sync_local_metadata_from_datastore
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
  from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
  from metaflow.mflog import TASK_LOG_SOURCE
  import metaflow.tracing as tracing

- from .kubernetes import Kubernetes, KubernetesKilledException, parse_kube_keyvalue_list
+ from .kubernetes import (
+     Kubernetes,
+     KubernetesKilledException,
+     parse_kube_keyvalue_list,
+     KubernetesException,
+ )
  from .kubernetes_decorator import KubernetesDecorator


@@ -109,6 +115,15 @@ def kubernetes():
  )
  @click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
  @click.option("--port", default=None, help="Port number to expose from the container")
+ @click.option(
+     "--ubf-context", default=None, type=click.Choice([None, UBF_CONTROL, UBF_TASK])
+ )
+ @click.option(
+     "--num-parallel",
+     default=None,
+     type=int,
+     help="Number of parallel nodes to run as a multi-node job.",
+ )
  @click.pass_context
  def step(
      ctx,
@@ -136,6 +151,7 @@ def step(
      tolerations=None,
      shared_memory=None,
      port=None,
+     num_parallel=None,
      **kwargs
  ):
      def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -167,6 +183,12 @@ def step(
          kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())
          env.update(split_vars)

+     if num_parallel is not None and num_parallel <= 1:
+         raise KubernetesException(
+             "Using @parallel with `num_parallel` <= 1 is not supported with Kubernetes. "
+             "Please set the value of `num_parallel` to be greater than 1."
+         )
+
      # Set retry policy.
      retry_count = int(kwargs.get("retry_count", 0))
      retry_deco = [deco for deco in node.decorators if deco.name == "retry"]
@@ -251,6 +273,7 @@ def step(
              tolerations=tolerations,
              shared_memory=shared_memory,
              port=port,
+             num_parallel=num_parallel,
          )
      except Exception as e:
          traceback.print_exc(chain=False)
@@ -4,7 +4,7 @@ import time

  from metaflow.exception import MetaflowException

- from .kubernetes_job import KubernetesJob
+ from .kubernetes_job import KubernetesJob, KubernetesJobSet


  CLIENT_REFRESH_INTERVAL_SECONDS = 300
@@ -61,5 +61,8 @@ class KubernetesClient(object):

          return self._client

+     def jobset(self, **kwargs):
+         return KubernetesJobSet(self, **kwargs)
+
      def job(self, **kwargs):
          return KubernetesJob(self, **kwargs)
@@ -32,6 +32,8 @@ from metaflow.sidecar import Sidecar

  from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
  from .kubernetes import KubernetesException, parse_kube_keyvalue_list
+ from metaflow.unbounded_foreach import UBF_CONTROL
+ from .kubernetes_jobsets import TaskIdConstructor

  try:
      unicode
@@ -239,11 +241,15 @@ class KubernetesDecorator(StepDecorator):
                  "Kubernetes. Please use one or the other.".format(step=step)
              )

-         for deco in decos:
-             if getattr(deco, "IS_PARALLEL", False):
-                 raise KubernetesException(
-                     "@kubernetes does not support parallel execution currently."
+         if any([deco.name == "parallel" for deco in decos]) and any(
+             [deco.name == "catch" for deco in decos]
+         ):
+             raise MetaflowException(
+                 "Step *{step}* contains a @parallel decorator "
+                 "with the @catch decorator. @catch is not supported with @parallel on Kubernetes.".format(
+                     step=step
                  )
+             )

          # Set run time limit for the Kubernetes job.
          self.run_time_limit = get_run_time_limit_for_task(decos)
@@ -421,6 +427,10 @@ class KubernetesDecorator(StepDecorator):
                  "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME"
              ]
              meta["kubernetes-node-ip"] = os.environ["METAFLOW_KUBERNETES_NODE_IP"]
+             if os.environ.get("METAFLOW_KUBERNETES_JOBSET_NAME"):
+                 meta["kubernetes-jobset-name"] = os.environ[
+                     "METAFLOW_KUBERNETES_JOBSET_NAME"
+                 ]

              # TODO (savin): Introduce equivalent support for Microsoft Azure and
              # Google Cloud Platform
@@ -453,6 +463,24 @@ class KubernetesDecorator(StepDecorator):
          self._save_logs_sidecar = Sidecar("save_logs_periodically")
          self._save_logs_sidecar.start()

+         num_parallel = None
+         if hasattr(flow, "_parallel_ubf_iter"):
+             num_parallel = flow._parallel_ubf_iter.num_parallel
+
+         if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL:
+             control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids(
+                 num_parallel
+             )
+             mapper_task_ids = [control_task_id] + worker_task_ids
+             flow._control_mapper_tasks = [
+                 "%s/%s/%s" % (run_id, step_name, mapper_task_id)
+                 for mapper_task_id in mapper_task_ids
+             ]
+             flow._control_task_is_mapper_zero = True
+
+         if num_parallel and num_parallel > 1:
+             _setup_multinode_environment()
+
      def task_finished(
          self, step_name, flow, graph, is_task_ok, retry_count, max_retries
      ):
@@ -486,3 +514,20 @@ class KubernetesDecorator(StepDecorator):
          cls.package_url, cls.package_sha = flow_datastore.save_data(
              [package.blob], len_hint=1
          )[0]
+
+
+ def _setup_multinode_environment():
+     import socket
+
+     os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"])
+     os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"]
+     if os.environ.get("CONTROL_INDEX") is not None:
+         os.environ["MF_PARALLEL_NODE_INDEX"] = str(0)
+     elif os.environ.get("WORKER_REPLICA_INDEX") is not None:
+         os.environ["MF_PARALLEL_NODE_INDEX"] = str(
+             int(os.environ["WORKER_REPLICA_INDEX"]) + 1
+         )
+     else:
+         raise MetaflowException(
+             "Jobset related ENV vars called $CONTROL_INDEX or $WORKER_REPLICA_INDEX not found"
+         )
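Taken together, the kubernetes.py, kubernetes_cli.py, kubernetes_client.py, and kubernetes_decorator.py changes add multi-node @parallel execution on Kubernetes backed by JobSets: launch_job routes to create_jobset when num_parallel is set, and the decorator maps the JobSet's MASTER_ADDR / WORLD_SIZE / CONTROL_INDEX / WORKER_REPLICA_INDEX variables onto MF_PARALLEL_*. A minimal sketch (not from the diff) of a flow that would exercise this path; step names, resource values, and num_parallel=4 are illustrative, and the CLI rejects num_parallel <= 1:

from metaflow import FlowSpec, current, kubernetes, parallel, step


class MultiNodeDemoFlow(FlowSpec):
    @step
    def start(self):
        # num_parallel > 1 fans the next step out into one control pod plus N-1 worker pods.
        self.next(self.train, num_parallel=4)

    @kubernetes(cpu=2, memory=4096)
    @parallel
    @step
    def train(self):
        # Values derived from the MF_PARALLEL_* environment variables set up by the decorator.
        print("node %d of %d" % (current.parallel.node_index, current.parallel.num_nodes))
        self.next(self.join)

    @step
    def join(self, inputs):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    MultiNodeDemoFlow()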