ob-metaflow 2.10.7.4__py2.py3-none-any.whl → 2.10.9.2__py2.py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.

Potentially problematic release.


This version of ob-metaflow might be problematic.

Files changed (57)
  1. metaflow/cards.py +2 -0
  2. metaflow/decorators.py +1 -1
  3. metaflow/metaflow_config.py +4 -0
  4. metaflow/plugins/__init__.py +4 -0
  5. metaflow/plugins/airflow/airflow_cli.py +1 -1
  6. metaflow/plugins/argo/argo_workflows.py +5 -0
  7. metaflow/plugins/argo/argo_workflows_cli.py +1 -1
  8. metaflow/plugins/aws/aws_utils.py +1 -1
  9. metaflow/plugins/aws/batch/batch.py +4 -0
  10. metaflow/plugins/aws/batch/batch_cli.py +3 -0
  11. metaflow/plugins/aws/batch/batch_client.py +40 -11
  12. metaflow/plugins/aws/batch/batch_decorator.py +1 -0
  13. metaflow/plugins/aws/step_functions/step_functions.py +1 -0
  14. metaflow/plugins/aws/step_functions/step_functions_cli.py +1 -1
  15. metaflow/plugins/azure/azure_exceptions.py +1 -1
  16. metaflow/plugins/cards/card_cli.py +413 -28
  17. metaflow/plugins/cards/card_client.py +16 -7
  18. metaflow/plugins/cards/card_creator.py +228 -0
  19. metaflow/plugins/cards/card_datastore.py +124 -26
  20. metaflow/plugins/cards/card_decorator.py +40 -86
  21. metaflow/plugins/cards/card_modules/base.html +12 -0
  22. metaflow/plugins/cards/card_modules/basic.py +74 -8
  23. metaflow/plugins/cards/card_modules/bundle.css +1 -170
  24. metaflow/plugins/cards/card_modules/card.py +65 -0
  25. metaflow/plugins/cards/card_modules/components.py +446 -81
  26. metaflow/plugins/cards/card_modules/convert_to_native_type.py +9 -3
  27. metaflow/plugins/cards/card_modules/main.js +250 -21
  28. metaflow/plugins/cards/card_modules/test_cards.py +117 -0
  29. metaflow/plugins/cards/card_resolver.py +0 -2
  30. metaflow/plugins/cards/card_server.py +361 -0
  31. metaflow/plugins/cards/component_serializer.py +506 -42
  32. metaflow/plugins/cards/exception.py +20 -1
  33. metaflow/plugins/datastores/azure_storage.py +1 -2
  34. metaflow/plugins/datastores/gs_storage.py +1 -2
  35. metaflow/plugins/datastores/s3_storage.py +2 -1
  36. metaflow/plugins/datatools/s3/s3.py +24 -11
  37. metaflow/plugins/env_escape/client.py +2 -12
  38. metaflow/plugins/env_escape/client_modules.py +18 -14
  39. metaflow/plugins/env_escape/server.py +18 -11
  40. metaflow/plugins/env_escape/utils.py +12 -0
  41. metaflow/plugins/gcp/gs_exceptions.py +1 -1
  42. metaflow/plugins/gcp/gs_utils.py +1 -1
  43. metaflow/plugins/kubernetes/kubernetes.py +43 -6
  44. metaflow/plugins/kubernetes/kubernetes_cli.py +40 -1
  45. metaflow/plugins/kubernetes/kubernetes_decorator.py +73 -6
  46. metaflow/plugins/kubernetes/kubernetes_job.py +536 -161
  47. metaflow/plugins/pypi/conda_environment.py +5 -6
  48. metaflow/plugins/pypi/pip.py +2 -2
  49. metaflow/plugins/pypi/utils.py +15 -0
  50. metaflow/task.py +1 -0
  51. metaflow/version.py +1 -1
  52. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/METADATA +1 -1
  53. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/RECORD +57 -55
  54. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/LICENSE +0 -0
  55. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/WHEEL +0 -0
  56. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/entry_points.txt +0 -0
  57. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/top_level.txt +0 -0

metaflow/plugins/cards/exception.py
@@ -116,7 +116,7 @@ class UnresolvableDatastoreException(MetaflowException):
         super(UnresolvableDatastoreException, self).__init__(msg)
 
 
-class IncorrectArguementException(MetaflowException):
+class IncorrectArgumentException(MetaflowException):
     headline = (
         "`get_cards` function requires a `Task` object or pathspec as an argument"
     )
@@ -138,3 +138,22 @@ class IncorrectPathspecException(MetaflowException):
             % pthspec
         )
         super().__init__(msg=msg, lineno=None)
+
+
+class ComponentOverwriteNotSupportedException(MetaflowException):
+    headline = "Component overwrite is not supported"
+
+    def __init__(self, component_id, card_id, card_type):
+        id_str = ""
+        if card_id is not None:
+            id_str = "id='%s'" % card_id
+        msg = (
+            "Card component overwrite is not supported. "
+            "Component with id %s already exists in the @card(type='%s', %s). \n"
+            "Instead of calling `current.card.components[ID] = MyComponent`. "
+            "You can overwrite the entire component Array by calling "
+            "`current.card.components = [MyComponent]`"
+        ) % (component_id, card_type, id_str)
+        super().__init__(
+            msg=msg,
+        )
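
For context, the new exception guards the card component API referenced in its own message. A minimal sketch of the pattern the message recommends (the flow, step, component id, and Markdown contents here are hypothetical):

    from metaflow import FlowSpec, card, current, step
    from metaflow.cards import Markdown

    class CardDemoFlow(FlowSpec):
        @card
        @step
        def start(self):
            current.card.append(Markdown("## original component"))
            # Overwriting a single component in place is what now raises
            # ComponentOverwriteNotSupportedException:
            #   current.card.components["some-id"] = Markdown("## new")
            # Replacing the whole component list is the supported route:
            current.card.components = [Markdown("## replaced contents")]
            self.next(self.end)

        @step
        def end(self):
            pass

    if __name__ == "__main__":
        CardDemoFlow()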

metaflow/plugins/datastores/azure_storage.py
@@ -250,9 +250,8 @@ class _AzureRootClient(object):
 class AzureStorage(DataStoreStorage):
     TYPE = "azure"
 
+    @check_azure_deps
     def __init__(self, root=None):
-        # cannot decorate __init__... invoke it with dummy decoratee
-        check_azure_deps(lambda: 0)
         super(AzureStorage, self).__init__(root)
         self._tmproot = ARTIFACT_LOCALROOT
         self._default_scope_token = None

metaflow/plugins/datastores/gs_storage.py
@@ -145,9 +145,8 @@ class _GSRootClient(object):
 class GSStorage(DataStoreStorage):
     TYPE = "gs"
 
+    @check_gs_deps
     def __init__(self, root=None):
-        # cannot decorate __init__... invoke it with dummy decoratee
-        check_gs_deps(lambda: 0)
         super(GSStorage, self).__init__(root)
         self._tmproot = ARTIFACT_LOCALROOT
         self._root_client = None

metaflow/plugins/datastores/s3_storage.py
@@ -2,7 +2,7 @@ import os
 
 from itertools import starmap
 
-from metaflow.plugins.datatools.s3.s3 import S3, S3Client, S3PutObject
+from metaflow.plugins.datatools.s3.s3 import S3, S3Client, S3PutObject, check_s3_deps
 from metaflow.metaflow_config import DATASTORE_SYSROOT_S3, ARTIFACT_LOCALROOT
 from metaflow.datastore.datastore_storage import CloseAfterUse, DataStoreStorage
 
@@ -18,6 +18,7 @@ except:
 class S3Storage(DataStoreStorage):
     TYPE = "s3"
 
+    @check_s3_deps
     def __init__(self, root=None):
         super(S3Storage, self).__init__(root)
         self.s3_client = S3Client()

metaflow/plugins/datatools/s3/s3.py
@@ -51,15 +51,25 @@ from .s3util import (
 if TYPE_CHECKING:
     from metaflow.client import Run
 
-try:
-    import boto3
-    from boto3.s3.transfer import TransferConfig
 
-    DOWNLOAD_FILE_THRESHOLD = 2 * TransferConfig().multipart_threshold
-    DOWNLOAD_MAX_CHUNK = 2 * 1024 * 1024 * 1024 - 1
-    boto_found = True
-except:
-    boto_found = False
+def _check_and_init_s3_deps():
+    try:
+        import boto3
+        from boto3.s3.transfer import TransferConfig
+    except (ImportError, ModuleNotFoundError):
+        raise MetaflowException("You need to install 'boto3' in order to use S3.")
+
+
+def check_s3_deps(func):
+    """The decorated function checks S3 dependencies (as needed for AWS S3 storage backend).
+    This includes boto3.
+    """
+
+    def _inner_func(*args, **kwargs):
+        _check_and_init_s3_deps()
+        return func(*args, **kwargs)
+
+    return _inner_func
 
 
 TEST_INJECT_RETRYABLE_FAILURES = int(
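
The practical effect of this change is that boto3 is no longer imported at module load time; the dependency check now runs the first time a decorated callable is invoked (the decorator is applied to S3.__init__ in a hunk below). A hedged sketch of the resulting behavior, assuming boto3 is not installed:

    from metaflow import S3                  # importing no longer requires boto3
    from metaflow.exception import MetaflowException

    try:
        with S3() as s3:                     # the decorated __init__ runs the dependency check
            pass
    except MetaflowException as e:
        print(e)                             # "You need to install 'boto3' in order to use S3."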
@@ -498,6 +508,7 @@ class S3(object):
     def get_root_from_config(cls, echo, create_on_absent=True):
         return DATATOOLS_S3ROOT
 
+    @check_s3_deps
     def __init__(
         self,
         tmproot: str = TEMPDIR,
@@ -508,9 +519,6 @@ class S3(object):
         encryption: Optional[str] = S3_SERVER_SIDE_ENCRYPTION,
         **kwargs
     ):
-        if not boto_found:
-            raise MetaflowException("You need to install 'boto3' in order to use S3.")
-
         if run:
             # 1. use a (current) run ID with optional customizations
             if DATATOOLS_S3ROOT is None:
@@ -875,6 +883,11 @@ class S3(object):
         `S3Object`
             An S3Object corresponding to the object requested.
         """
+        from boto3.s3.transfer import TransferConfig
+
+        DOWNLOAD_FILE_THRESHOLD = 2 * TransferConfig().multipart_threshold
+        DOWNLOAD_MAX_CHUNK = 2 * 1024 * 1024 * 1024 - 1
+
         url, r = self._url_and_range(key)
         src = urlparse(url)
 

metaflow/plugins/env_escape/client.py
@@ -37,6 +37,7 @@ from .data_transferer import DataTransferer, ObjReference
 from .exception_transferer import load_exception
 from .override_decorators import LocalAttrOverride, LocalException, LocalOverride
 from .stub import create_class
+from .utils import get_canonical_name
 
 BIND_TIMEOUT = 0.1
 BIND_RETRY = 0
@@ -336,7 +337,7 @@ class Client(object):
     def get_local_class(self, name, obj_id=None):
         # Gets (and creates if needed), the class mapping to the remote
         # class of name 'name'.
-        name = self._get_canonical_name(name)
+        name = get_canonical_name(name, self._aliases)
         if name == "function":
             # Special handling of pickled functions. We create a new class that
             # simply has a __call__ method that will forward things back to
@@ -398,17 +399,6 @@ class Client(object):
         local_instance = local_class(self, remote_class_name, obj_id)
         return local_instance
 
-    def _get_canonical_name(self, name):
-        # We look at the aliases looking for the most specific match first
-        base_name = self._aliases.get(name)
-        if base_name is not None:
-            return base_name
-        for idx in reversed([pos for pos, char in enumerate(name) if char == "."]):
-            base_name = self._aliases.get(name[:idx])
-            if base_name is not None:
-                return ".".join([base_name, name[idx + 1 :]])
-        return name
-
     def _communicate(self, msg):
         if os.getpid() != self._active_pid:
             raise RuntimeError(

metaflow/plugins/env_escape/client_modules.py
@@ -8,6 +8,7 @@ import sys
 from .consts import OP_CALLFUNC, OP_GETVAL, OP_SETVAL
 from .client import Client
 from .override_decorators import LocalException
+from .utils import get_canonical_name
 
 
 def _clean_client(client):
@@ -23,6 +24,7 @@ class _WrappedModule(object):
             r"^%s\.([a-zA-Z_][a-zA-Z0-9_]*)$" % prefix.replace(".", r"\.") # noqa W605
         )
         self._exports = {}
+        self._aliases = exports["aliases"]
         for k in ("classes", "functions", "values"):
             result = []
             for item in exports[k]:
@@ -43,6 +45,11 @@ class _WrappedModule(object):
             return self._prefix
         if name in ("__file__", "__path__"):
             return self._client.name
+
+        # Make the name canonical because the prefix is also canonical.
+        name = get_canonical_name(self._prefix + "." + name, self._aliases)[
+            len(self._prefix) + 1 :
+        ]
         if name in self._exports["classes"]:
             # We load classes lazily
             return self._client.get_local_class("%s.%s" % (self._prefix, name))
@@ -87,6 +94,7 @@ class _WrappedModule(object):
             "_client",
             "_exports",
             "_exception_classes",
+            "_aliases",
         ):
             object.__setattr__(self, name, value)
             return
@@ -95,6 +103,11 @@ class _WrappedModule(object):
             # module when loading
             object.__setattr__(self, name, value)
             return
+
+        # Make the name canonical because the prefix is also canonical.
+        name = get_canonical_name(self._prefix + "." + name, self._aliases)[
+            len(self._prefix) + 1 :
+        ]
         if name in self._exports["values"]:
             self._client.stub_request(
                 None, OP_SETVAL, "%s.%s" % (self._prefix, name), value
@@ -126,7 +139,7 @@ class ModuleImporter(object):
 
     def find_module(self, fullname, path=None):
         if self._handled_modules is not None:
-            if fullname in self._handled_modules:
+            if get_canonical_name(fullname, self._aliases) in self._handled_modules:
                 return self
             return None
         if any([fullname.startswith(prefix) for prefix in self._module_prefixes]):
@@ -224,24 +237,15 @@ class ModuleImporter(object):
             self._handled_modules[prefix] = _WrappedModule(
                 self, prefix, exports, formed_exception_classes, self._client
             )
-        fullname = self._get_canonical_name(fullname)
-        module = self._handled_modules.get(fullname)
+        canonical_fullname = get_canonical_name(fullname, self._aliases)
+        # Modules are created canonically but we need to return something for any
+        # of the aliases.
+        module = self._handled_modules.get(canonical_fullname)
         if module is None:
             raise ImportError
         sys.modules[fullname] = module
         return module
 
-    def _get_canonical_name(self, name):
-        # We look at the aliases looking for the most specific match first
-        base_name = self._aliases.get(name)
-        if base_name is not None:
-            return base_name
-        for idx in reversed([pos for pos, char in enumerate(name) if char == "."]):
-            base_name = self._aliases.get(name[:idx])
-            if base_name is not None:
-                return ".".join([base_name, name[idx + 1 :]])
-        return name
-
 
 
 def create_modules(python_executable, pythonpath, max_pickle_version, path, prefixes):

metaflow/plugins/env_escape/server.py
@@ -53,7 +53,7 @@ from .override_decorators import (
     RemoteExceptionSerializer,
 )
 from .exception_transferer import dump_exception
-from .utils import get_methods
+from .utils import get_methods, get_canonical_name
 
 BIND_TIMEOUT = 0.1
 BIND_RETRY = 1
@@ -61,7 +61,6 @@ BIND_RETRY = 1
 
 class Server(object):
     def __init__(self, config_dir, max_pickle_version):
-
         self._max_pickle_version = data_transferer.defaultProtocol = max_pickle_version
         try:
             mappings = importlib.import_module(".server_mappings", package=config_dir)
@@ -108,6 +107,11 @@ class Server(object):
             for alias in aliases:
                 a = self._aliases.setdefault(alias, base_name)
                 if a != base_name:
+                    # Technically we could have a that aliases b and b that aliases c
+                    # and then a that aliases c. This would error out in that case
+                    # even though it is valid. It is easy for the user to get around
+                    # this by listing aliases in the same order so we don't support
+                    # it for now.
                     raise ValueError(
                         "%s is an alias to both %s and %s" % (alias, base_name, a)
                     )
@@ -155,12 +159,13 @@ class Server(object):
         parent_to_child = {}
 
         for ex_name, ex_cls in self._known_exceptions.items():
+            ex_name_canonical = get_canonical_name(ex_name, self._aliases)
             parents = []
             for base in ex_cls.__mro__[1:]:
                 if base is object:
                     raise ValueError(
-                        "Exported exceptions not rooted in a builtin exception are not supported: %s"
-                        % ex_name
+                        "Exported exceptions not rooted in a builtin exception "
+                        "are not supported: %s." % ex_name
                     )
                 if base.__module__ == "builtins":
                     # We found our base exception
@@ -168,17 +173,19 @@ class Server(object):
                     break
                 else:
                     fqn = ".".join([base.__module__, base.__name__])
-                    if fqn in self._known_exceptions:
-                        parents.append(fqn)
-                        children = parent_to_child.setdefault(fqn, [])
-                        children.append(ex_name)
+                    canonical_fqn = get_canonical_name(fqn, self._aliases)
+                    if canonical_fqn in self._known_exceptions:
+                        parents.append(canonical_fqn)
+                        children = parent_to_child.setdefault(canonical_fqn, [])
+                        children.append(ex_name_canonical)
                     else:
                         raise ValueError(
                             "Exported exception %s has non exported and non builtin parent "
-                            "exception: %s" % (ex_name, fqn)
+                            "exception: %s. Known exceptions: %s"
+                            % (ex_name, fqn, str(self._known_exceptions))
                         )
-            name_to_parent_count[ex_name] = len(parents) - 1
-            name_to_parents[ex_name] = parents
+            name_to_parent_count[ex_name_canonical] = len(parents) - 1
+            name_to_parents[ex_name_canonical] = parents
 
         # We now form the exceptions and put them in self._known_exceptions in
         # the proper order (topologically)

metaflow/plugins/env_escape/utils.py
@@ -20,3 +20,15 @@ def get_methods(class_object):
         elif isinstance(attribute, classmethod):
             all_methods["___c___%s" % name] = inspect.getdoc(attribute)
     return all_methods
+
+
+def get_canonical_name(name, aliases):
+    # We look at the aliases looking for the most specific match first
+    base_name = aliases.get(name)
+    if base_name is not None:
+        return base_name
+    for idx in reversed([pos for pos, char in enumerate(name) if char == "."]):
+        base_name = aliases.get(name[:idx])
+        if base_name is not None:
+            return ".".join([base_name, name[idx + 1 :]])
+    return name
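
The helper resolves the longest matching alias prefix first. An illustrative call with a hypothetical alias table (the aliases below are not part of the package):

    from metaflow.plugins.env_escape.utils import get_canonical_name

    aliases = {"np": "numpy", "np.linalg": "numpy.linalg"}

    get_canonical_name("np.linalg.norm", aliases)    # -> "numpy.linalg.norm"
    get_canonical_name("np.random.rand", aliases)    # -> "numpy.random.rand"
    get_canonical_name("pandas.DataFrame", aliases)  # -> "pandas.DataFrame" (no alias applies)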

metaflow/plugins/gcp/gs_exceptions.py
@@ -2,4 +2,4 @@ from metaflow.exception import MetaflowException
 
 
 class MetaflowGSPackageError(MetaflowException):
-    headline = "Missing required packages google-cloud-storage google-auth"
+    headline = "Missing required packages 'google-cloud-storage' and 'google-auth'"

metaflow/plugins/gcp/gs_utils.py
@@ -34,7 +34,7 @@ def _check_and_init_gs_deps():
 
 
 def check_gs_deps(func):
-    """The decorated function checks GS dependencies (as needed for Azure storage backend). This includes
+    """The decorated function checks GS dependencies (as needed for Google Cloud storage backend). This includes
     various GCP SDK packages, as well as a Python version of >=3.7
     """
 

metaflow/plugins/kubernetes/kubernetes.py
@@ -4,6 +4,7 @@ import os
 import re
 import shlex
 import time
+import copy
 from typing import Dict, List, Optional
 import uuid
 from uuid import uuid4
@@ -174,6 +175,10 @@ class Kubernetes(object):
         persistent_volume_claims=None,
         tolerations=None,
         labels=None,
+        annotations=None,
+        num_parallel=0,
+        attrs={},
+        port=None,
     ):
         if env is None:
             env = {}
@@ -213,6 +218,9 @@ class Kubernetes(object):
                 tmpfs_size=tmpfs_size,
                 tmpfs_path=tmpfs_path,
                 persistent_volume_claims=persistent_volume_claims,
+                num_parallel=num_parallel,
+                attrs=attrs,
+                port=port,
             )
             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
             .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -266,6 +274,7 @@ class Kubernetes(object):
             # see get_datastore_root_from_config in datastore/local.py).
         )
 
+        self.num_parallel = num_parallel
         # Temporary passing of *some* environment variables. Do not rely on this
         # mechanism as it will be removed in the near future
         for k, v in config_values():
@@ -341,7 +350,7 @@ class Kubernetes(object):
             sigmoid = 1.0 / (1.0 + math.exp(-0.01 * secs_since_start + 9.0))
             return 0.5 + sigmoid * 30.0
 
-        def wait_for_launch(job):
+        def wait_for_launch(job, child_jobs):
             status = job.status
             echo(
                 "Task is starting (%s)..." % status,
@@ -351,11 +360,38 @@ class Kubernetes(object):
             t = time.time()
             start_time = time.time()
             while job.is_waiting:
-                new_status = job.status
-                if status != new_status or (time.time() - t) > 30:
-                    status = new_status
+                # new_status = job.status
+                if status != job.status or (time.time() - t) > 30:
+                    if not child_jobs:
+                        child_statuses = ""
+                    else:
+                        status_keys = set(
+                            [child_job.status for child_job in child_jobs]
+                        )
+                        status_counts = [
+                            (
+                                status,
+                                len(
+                                    [
+                                        child_job.status == status
+                                        for child_job in child_jobs
+                                    ]
+                                ),
+                            )
+                            for status in status_keys
+                        ]
+                        child_statuses = " (parallel node status: [{}])".format(
+                            ", ".join(
+                                [
+                                    "{}:{}".format(status, num)
+                                    for (status, num) in sorted(status_counts)
+                                ]
+                            )
+                        )
+
+                    status = job.status
                     echo(
-                        "Task is starting (%s)..." % status,
+                        "Task is starting (status %s)... %s" % (status, child_statuses),
                         "stderr",
                         job_id=job.id,
                     )
@@ -367,8 +403,9 @@ class Kubernetes(object):
         stdout_tail = get_log_tailer(stdout_location, self._datastore.TYPE)
         stderr_tail = get_log_tailer(stderr_location, self._datastore.TYPE)
 
+        child_jobs = []
         # 1) Loop until the job has started
-        wait_for_launch(self._job)
+        wait_for_launch(self._job, child_jobs)
 
         # 2) Tail logs until the job has finished
         tail_logs(

metaflow/plugins/kubernetes/kubernetes_cli.py
@@ -107,6 +107,26 @@ def kubernetes():
     type=JSONTypeClass(),
     multiple=False,
 )
+@click.option(
+    "--labels",
+    default=None,
+    type=JSONTypeClass(),
+    multiple=False,
+)
+@click.option(
+    "--annotations",
+    default=None,
+    type=JSONTypeClass(),
+    multiple=False,
+)
+@click.option("--ubf-context", default=None, type=click.Choice([None, "ubf_control"]))
+@click.option(
+    "--num-parallel",
+    default=0,
+    type=int,
+    help="Number of parallel nodes to run as a multi-node job.",
+)
+@click.option("--port", default=None, help="port number")
 @click.pass_context
 def step(
     ctx,
@@ -132,6 +152,10 @@ def step(
     run_time_limit=None,
     persistent_volume_claims=None,
     tolerations=None,
+    labels=None,
+    annotations=None,
+    num_parallel=None,
+    port=None,
     **kwargs
 ):
     def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -177,11 +201,17 @@ def step(
         )
         time.sleep(minutes_between_retries * 60)
 
+    step_args = " ".join(util.dict_to_cli_options(kwargs))
+    num_parallel = num_parallel or 0
+    if num_parallel and num_parallel > 1:
+        # For multinode, we need to add a placeholder that can be mutated by the caller
+        step_args += " [multinode-args]"
+
     step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
         entrypoint="%s -u %s" % (executable, os.path.basename(sys.argv[0])),
         top_args=" ".join(util.dict_to_cli_options(ctx.parent.parent.params)),
         step=step_name,
-        step_args=" ".join(util.dict_to_cli_options(kwargs)),
+        step_args=step_args,
     )
 
     # Set log tailing.
@@ -207,6 +237,10 @@ def step(
                 ),
             )
 
+    attrs = {
+        "metaflow.task_id": kwargs["task_id"],
+        "requires_passwordless_ssh": any([getattr(deco, "requires_passwordless_ssh", False) for deco in node.decorators]),
+    }
     try:
         kubernetes = Kubernetes(
             datastore=ctx.obj.flow_datastore,
@@ -245,6 +279,11 @@ def step(
                 env=env,
                 persistent_volume_claims=persistent_volume_claims,
                 tolerations=tolerations,
+                labels=labels,
+                annotations=annotations,
+                num_parallel=num_parallel,
+                port=port,
+                attrs=attrs,
             )
     except Exception as e:
         traceback.print_exc(chain=False)

metaflow/plugins/kubernetes/kubernetes_decorator.py
@@ -2,6 +2,7 @@ import json
 import os
 import platform
 import sys
+import time
 
 from metaflow import current
 from metaflow.decorators import StepDecorator
@@ -20,10 +21,12 @@ from metaflow.metaflow_config import (
     KUBERNETES_PERSISTENT_VOLUME_CLAIMS,
     KUBERNETES_TOLERATIONS,
     KUBERNETES_SERVICE_ACCOUNT,
+    KUBERNETES_PORT,
 )
 from metaflow.plugins.resources_decorator import ResourcesDecorator
 from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
 from metaflow.sidecar import Sidecar
+from metaflow.unbounded_foreach import UBF_CONTROL
 
 from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
 from .kubernetes import KubernetesException, parse_kube_keyvalue_list
@@ -88,6 +91,8 @@ class KubernetesDecorator(StepDecorator):
     persistent_volume_claims: Dict[str, str], optional
         A map (dictionary) of persistent volumes to be mounted to the pod for this step. The map is from persistent
        volumes to the path to which the volume is to be mounted, e.g., `{'pvc-name': '/path/to/mount/on'}`.
+    port: int, optional
+        Number of the port to specify in the Kubernetes job object
     """
 
     name = "kubernetes"
@@ -110,6 +115,7 @@ class KubernetesDecorator(StepDecorator):
         "tmpfs_size": None,
         "tmpfs_path": "/metaflow_temp",
         "persistent_volume_claims": None, # e.g., {"pvc-name": "/mnt/vol", "another-pvc": "/mnt/vol2"}
+        "port": None,
     }
     package_url = None
     package_sha = None
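
A hedged example of the new `port` attribute in use on the @kubernetes decorator; the step body and the other decorator arguments shown are illustrative only:

    from metaflow import FlowSpec, kubernetes, step

    class ServeFlow(FlowSpec):
        @kubernetes(cpu=2, memory=8000, port=8888)  # port is the newly added attribute
        @step
        def start(self):
            self.next(self.end)

        @step
        def end(self):
            pass

    if __name__ == "__main__":
        ServeFlow()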
@@ -195,6 +201,8 @@ class KubernetesDecorator(StepDecorator):
         if not self.attributes["tmpfs_size"]:
             # default tmpfs behavior - https://man7.org/linux/man-pages/man5/tmpfs.5.html
             self.attributes["tmpfs_size"] = int(self.attributes["memory"]) // 2
+        if not self.attributes["port"]:
+            self.attributes["port"] = KUBERNETES_PORT
 
     # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
     def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
@@ -216,12 +224,6 @@ class KubernetesDecorator(StepDecorator):
                 "Kubernetes. Please use one or the other.".format(step=step)
             )
 
-        for deco in decos:
-            if getattr(deco, "IS_PARALLEL", False):
-                raise KubernetesException(
-                    "@kubernetes does not support parallel execution currently."
-                )
-
         # Set run time limit for the Kubernetes job.
         self.run_time_limit = get_run_time_limit_for_task(decos)
         if self.run_time_limit < 60:
@@ -432,6 +434,27 @@ class KubernetesDecorator(StepDecorator):
         self._save_logs_sidecar = Sidecar("save_logs_periodically")
         self._save_logs_sidecar.start()
 
+        num_parallel = int(os.environ.get("WORLD_SIZE", 0))
+        if num_parallel >= 1:
+            if ubf_context == UBF_CONTROL:
+                control_task_id = current.task_id
+                top_task_id = control_task_id.replace("control-", "")
+                mapper_task_ids = [control_task_id] + [
+                    "%s-node-%d" % (top_task_id, node_idx)
+                    for node_idx in range(1, num_parallel)
+                ]
+                flow._control_mapper_tasks = [
+                    "%s/%s/%s" % (run_id, step_name, mapper_task_id)
+                    for mapper_task_id in mapper_task_ids
+                ]
+                flow._control_task_is_mapper_zero = True
+            else:
+                worker_job_rank = int(os.environ["RANK"])
+                os.environ["RANK"] = str(worker_job_rank + 1)
+
+        if num_parallel >= 1:
+            _setup_multinode_environment()
+
     def task_finished(
         self, step_name, flow, graph, is_task_ok, retry_count, max_retries
     ):
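
A worked example of the task-id bookkeeping above, using a hypothetical control task id and num_parallel of 3:

    # control_task_id = "control-42", num_parallel = 3
    # top_task_id     = "42"
    # mapper_task_ids = ["control-42", "42-node-1", "42-node-2"]
    # flow._control_mapper_tasks = [
    #     "<run_id>/<step_name>/control-42",
    #     "<run_id>/<step_name>/42-node-1",
    #     "<run_id>/<step_name>/42-node-2",
    # ]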
@@ -459,9 +482,53 @@ class KubernetesDecorator(StepDecorator):
             # Best effort kill
             pass
 
+        if is_task_ok and len(getattr(flow, "_control_mapper_tasks", [])) > 1:
+            self._wait_for_mapper_tasks(flow, step_name)
+
+    def _wait_for_mapper_tasks(self, flow, step_name):
+        """
+        When launching multinode task with UBF, need to wait for the secondary
+        tasks to finish cleanly and produce their output before exiting the
+        main task. Otherwise, the main task finishing will cause secondary nodes
+        to terminate immediately, and possibly prematurely.
+        """
+        from metaflow import Step # avoid circular dependency
+
+        TIMEOUT = 600
+        last_completion_timeout = time.time() + TIMEOUT
+        print("Waiting for batch secondary tasks to finish")
+        while last_completion_timeout > time.time():
+            time.sleep(2)
+            try:
+                step_path = "%s/%s/%s" % (flow.name, current.run_id, step_name)
+                tasks = [task for task in Step(step_path)]
+                if len(tasks) == len(flow._control_mapper_tasks):
+                    if all(
+                        task.finished_at is not None for task in tasks
+                    ): # for some reason task.finished fails
+                        return True
+                    else:
+                        print(
+                            "Waiting for all parallel tasks to finish. Finished: {}/{}".format(
+                                len(tasks),
+                                len(flow._control_mapper_tasks),
+                            )
+                        )
+            except Exception as e:
+                pass
+        raise Exception(
+            "Batch secondary workers did not finish in %s seconds" % TIMEOUT
+        )
+
     @classmethod
     def _save_package_once(cls, flow_datastore, package):
         if cls.package_url is None:
             cls.package_url, cls.package_sha = flow_datastore.save_data(
                 [package.blob], len_hint=1
             )[0]
+
+def _setup_multinode_environment():
+    import socket
+    os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"])
+    os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"]
+    os.environ["MF_PARALLEL_NODE_INDEX"] = os.environ["RANK"]