metaflow 2.12.19 → 2.12.21 (py2.py3-none-any.whl)

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (33)
  1. metaflow/__init__.py +11 -21
  2. metaflow/client/core.py +1 -1
  3. metaflow/cmd/main_cli.py +3 -2
  4. metaflow/extension_support/__init__.py +120 -29
  5. metaflow/flowspec.py +4 -0
  6. metaflow/info_file.py +25 -0
  7. metaflow/metaflow_config.py +0 -1
  8. metaflow/metaflow_current.py +3 -1
  9. metaflow/metaflow_environment.py +1 -7
  10. metaflow/metaflow_version.py +130 -64
  11. metaflow/package.py +2 -1
  12. metaflow/plugins/argo/argo_workflows.py +10 -1
  13. metaflow/plugins/aws/batch/batch_client.py +3 -0
  14. metaflow/plugins/kubernetes/kube_utils.py +25 -0
  15. metaflow/plugins/kubernetes/kubernetes.py +3 -0
  16. metaflow/plugins/kubernetes/kubernetes_cli.py +84 -1
  17. metaflow/plugins/kubernetes/kubernetes_client.py +97 -0
  18. metaflow/plugins/kubernetes/kubernetes_decorator.py +4 -0
  19. metaflow/plugins/parallel_decorator.py +4 -0
  20. metaflow/plugins/pypi/bootstrap.py +2 -0
  21. metaflow/plugins/pypi/conda_decorator.py +7 -1
  22. metaflow/runner/click_api.py +13 -1
  23. metaflow/runner/deployer.py +9 -2
  24. metaflow/runner/metaflow_runner.py +4 -2
  25. metaflow/runner/subprocess_manager.py +8 -3
  26. metaflow/runner/utils.py +19 -2
  27. metaflow/version.py +1 -1
  28. {metaflow-2.12.19.dist-info → metaflow-2.12.21.dist-info}/METADATA +2 -2
  29. {metaflow-2.12.19.dist-info → metaflow-2.12.21.dist-info}/RECORD +33 -31
  30. {metaflow-2.12.19.dist-info → metaflow-2.12.21.dist-info}/WHEEL +1 -1
  31. {metaflow-2.12.19.dist-info → metaflow-2.12.21.dist-info}/LICENSE +0 -0
  32. {metaflow-2.12.19.dist-info → metaflow-2.12.21.dist-info}/entry_points.txt +0 -0
  33. {metaflow-2.12.19.dist-info → metaflow-2.12.21.dist-info}/top_level.txt +0 -0
metaflow/metaflow_version.py CHANGED
@@ -7,11 +7,15 @@ See the documentation of get_version for more information

  # This file is adapted from https://github.com/aebrahim/python-git-version

- from subprocess import check_output, CalledProcessError
- from os import path, name, devnull, environ, listdir
- import json
+ import subprocess
+ from os import path, name, environ, listdir

- from metaflow import CURRENT_DIRECTORY, INFO_FILE
+ from metaflow.extension_support import update_package_info
+ from metaflow.info_file import CURRENT_DIRECTORY, read_info_file
+
+
+ # True/False correspond to the value `public` in get_version
+ _version_cache = {True: None, False: None}

  __all__ = ("get_version",)

@@ -57,87 +61,149 @@ if name == "nt":
      GIT_COMMAND = find_git_on_windows()


- def call_git_describe(abbrev=7):
+ def call_git_describe(file_to_check, abbrev=7):
      """return the string output of git describe"""
      try:
-
-         # first, make sure we are actually in a Metaflow repo,
-         # not some other repo
-         with open(devnull, "w") as fnull:
-             arguments = [GIT_COMMAND, "rev-parse", "--show-toplevel"]
-             reponame = (
-                 check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull)
-                 .decode("ascii")
-                 .strip()
-             )
-             if path.basename(reponame) != "metaflow":
-                 return None
-
-         with open(devnull, "w") as fnull:
-             arguments = [GIT_COMMAND, "describe", "--tags", "--abbrev=%d" % abbrev]
-             return (
-                 check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull)
-                 .decode("ascii")
-                 .strip()
-             )
-
-     except (OSError, CalledProcessError):
+         wd = path.dirname(file_to_check)
+         filename = path.basename(file_to_check)
+
+         # First check if the file is tracked in the GIT repository we are in
+         # We do this because in some setups and for some bizarre reason, python files
+         # are installed directly into a git repository (I am looking at you brew). We
+         # don't want to consider this a GIT install in that case.
+         args = [GIT_COMMAND, "ls-files", "--error-unmatch", filename]
+         git_return_code = subprocess.run(
+             args,
+             cwd=wd,
+             stderr=subprocess.DEVNULL,
+             stdout=subprocess.DEVNULL,
+             check=False,
+         ).returncode
+
+         if git_return_code != 0:
+             return None
+
+         args = [
+             GIT_COMMAND,
+             "describe",
+             "--tags",
+             "--dirty",
+             "--long",
+             "--abbrev=%d" % abbrev,
+         ]
+         return (
+             subprocess.check_output(args, cwd=wd, stderr=subprocess.DEVNULL)
+             .decode("ascii")
+             .strip()
+         )
+
+     except (OSError, subprocess.CalledProcessError):
          return None


- def format_git_describe(git_str, pep440=False):
+ def format_git_describe(git_str, public=False):
      """format the result of calling 'git describe' as a python version"""
      if git_str is None:
          return None
-     if "-" not in git_str:  # currently at a tag
-         return git_str
+     splits = git_str.split("-")
+     if len(splits) == 4:
+         # Formatted as <tag>-<post>-<hash>-dirty
+         tag, post, h = splits[:3]
+         dirty = "-" + splits[3]
      else:
-         # formatted as version-N-githash
-         # want to convert to version.postN-githash
-         git_str = git_str.replace("-", ".post", 1)
-         if pep440:  # does not allow git hash afterwards
-             return git_str.split("-")[0]
-         else:
-             return git_str.replace("-g", "+git")
+         # Formatted as <tag>-<post>-<hash>
+         tag, post, h = splits
+         dirty = ""
+     if post == "0":
+         if public:
+             return tag
+         return tag + dirty
+
+     if public:
+         return "%s.post%s" % (tag, post)
+
+     return "%s.post%s-git%s%s" % (tag, post, h[1:], dirty)


  def read_info_version():
      """Read version information from INFO file"""
-     try:
-         with open(INFO_FILE, "r") as contents:
-             return json.load(contents).get("metaflow_version")
-     except IOError:
-         return None
+     info_file = read_info_file()
+     if info_file:
+         return info_file.get("metaflow_version")
+     return None


- def get_version(pep440=False):
+ def get_version(public=False):
      """Tracks the version number.

-     pep440: bool
-         When True, this function returns a version string suitable for
-         a release as defined by PEP 440. When False, the githash (if
-         available) will be appended to the version string.
+     public: bool
+         When True, this function returns a *public* version specification which
+         doesn't include any local information (dirtiness or hash). See
+         https://packaging.python.org/en/latest/specifications/version-specifiers/#version-scheme

-     If the script is located within an active git repository,
-     git-describe is used to get the version information.
+     We first check the INFO file to see if we recorded a version of Metaflow. If there
+     is none, we check if we are in a GIT repository and if so, form the version
+     from that.

-     Otherwise, the version logged by package installer is returned.
-
-     If even that information isn't available (likely when executing on a
-     remote cloud instance), the version information is returned from INFO file
-     in the current directory.
+     Otherwise, we return the version of Metaflow that was installed.

      """

-     version = format_git_describe(call_git_describe(), pep440=pep440)
-     version_addl = None
-     if version is None:  # not a git repository
-         import metaflow
-
+     global _version_cache
+
+     # To get the version we do the following:
+     #  - Check if we have a cached version. If so, return that
+     #  - Then check if we have an INFO file present. If so, use that as it is
+     #    the most reliable way to get the version. In particular, when running remotely,
+     #    metaflow is installed in a directory and if any extension is using distutils to
+     #    determine its version, this would return None and querying the version directly
+     #    from the extension would fail to produce the correct result
+     #  - Then if we are in the GIT repository and if so, use the git describe
+     #  - If we don't have an INFO file, we look at the version information that is
+     #    populated by metaflow and the extensions.
+
+     if _version_cache[public] is not None:
+         return _version_cache[public]
+
+     version = (
+         read_info_version()
+     )  # Version info is cached in INFO file; includes extension info
+
+     if version:
+         _version_cache[public] = version
+         return version
+
+     # Get the version for Metaflow, favor the GIT version
+     import metaflow
+
+     version = format_git_describe(
+         call_git_describe(file_to_check=metaflow.__file__), public=public
+     )
+     if version is None:
          version = metaflow.__version__
-         version_addl = metaflow.__version_addl__
-     if version is None:  # not a proper python package
-         return read_info_version()
-     if version_addl:
-         return "+".join([version, version_addl])
+
+     # Look for extensions and compute their versions. Properly formed extensions have
+     # a toplevel file which will contain a __mf_extensions__ value and a __version__
+     # value. We already saved the properly formed modules when loading metaflow in
+     # __ext_tl_modules__.
+     ext_versions = []
+     for pkg_name, extension_module in metaflow.__ext_tl_modules__:
+         ext_name = getattr(extension_module, "__mf_extensions__", "<unk>")
+         ext_version = format_git_describe(
+             call_git_describe(file_to_check=extension_module.__file__), public=public
+         )
+         if ext_version is None:
+             ext_version = getattr(extension_module, "__version__", "<unk>")
+         # Update the package information about reported version for the extension
+         update_package_info(
+             package_name=pkg_name,
+             extension_name=ext_name,
+             package_version=ext_version,
+         )
+         ext_versions.append("%s(%s)" % (ext_name, ext_version))
+
+     # We now have all the information about extensions so we can form the final string
+     if ext_versions:
+         version = version + "+" + ";".join(ext_versions)
+     _version_cache[public] = version
      return version
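
For readers tracking the behavior change: `git describe` is now invoked with `--long --dirty`, and its output is reformatted as shown above. A sketch (not part of the diff; tag and hash values are invented) of what the new `format_git_describe` produces:

  format_git_describe("2.12.21-0-g1a2b3c4")                     # -> "2.12.21"
  format_git_describe("2.12.21-3-g1a2b3c4")                     # -> "2.12.21.post3-git1a2b3c4"
  format_git_describe("2.12.21-3-g1a2b3c4-dirty")               # -> "2.12.21.post3-git1a2b3c4-dirty"
  format_git_describe("2.12.21-3-g1a2b3c4-dirty", public=True)  # -> "2.12.21.post3"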
metaflow/package.py CHANGED
@@ -10,7 +10,8 @@ from .extension_support import EXT_PKG, package_mfext_all
  from .metaflow_config import DEFAULT_PACKAGE_SUFFIXES
  from .exception import MetaflowException
  from .util import to_unicode
- from . import R, INFO_FILE
+ from . import R
+ from .info_file import INFO_FILE

  DEFAULT_SUFFIXES_LIST = DEFAULT_PACKAGE_SUFFIXES.split(",")
  METAFLOW_SUFFIXES_LIST = [".py", ".html", ".css", ".js"]
metaflow/plugins/argo/argo_workflows.py CHANGED
@@ -1905,6 +1905,12 @@ class ArgoWorkflows(object):
              jobset.environment_variable(
                  "MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}"
              )
+             # We need this task-id set so that all the nodes are aware of the control
+             # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
+             jobset.environment_variable(
+                 "MF_PARALLEL_CONTROL_TASK_ID",
+                 "control-{{inputs.parameters.task-id-entropy}}-0",
+             )
              # for k, v in .items():
              jobset.environment_variables_from_selectors(
                  {
@@ -2552,8 +2558,8 @@ class ArgoWorkflows(object):
          cmd_str = " && ".join([init_cmds, heartbeat_cmds])
          cmds = shlex.split('bash -c "%s"' % cmd_str)

-         # TODO: Check that this is the minimal env.
          # Env required for sending heartbeats to the metadata service, nothing extra.
+         # prod token / runtime info is required to correctly register flow branches
          env = {
              # These values are needed by Metaflow to set it's internal
              # state appropriately.
@@ -2565,7 +2571,10 @@ class ArgoWorkflows(object):
              "METAFLOW_USER": "argo-workflows",
              "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
              "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
+             "METAFLOW_KUBERNETES_WORKLOAD": 1,
+             "METAFLOW_RUNTIME_ENVIRONMENT": "kubernetes",
              "METAFLOW_OWNER": self.username,
+             "METAFLOW_PRODUCTION_TOKEN": self.production_token,
          }
          # support Metaflow sandboxes
          env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
metaflow/plugins/aws/batch/batch_client.py CHANGED
@@ -89,6 +89,9 @@ class BatchJob(object):
          # Multinode
          if getattr(self, "num_parallel", 0) >= 1:
              num_nodes = self.num_parallel
+             # We need this task-id set so that all the nodes are aware of the control
+             # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
+             self.environment_variable("MF_PARALLEL_CONTROL_TASK_ID", self._task_id)
              main_task_override = copy.deepcopy(self.payload["containerOverrides"])

              # main
metaflow/plugins/kubernetes/kube_utils.py ADDED
@@ -0,0 +1,25 @@
+ from metaflow.exception import CommandException
+ from metaflow.util import get_username, get_latest_run_id
+
+
+ def parse_cli_options(flow_name, run_id, user, my_runs, echo):
+     if user and my_runs:
+         raise CommandException("--user and --my-runs are mutually exclusive.")
+
+     if run_id and my_runs:
+         raise CommandException("--run_id and --my-runs are mutually exclusive.")
+
+     if my_runs:
+         user = get_username()
+
+     latest_run = True
+
+     if user and not run_id:
+         latest_run = False
+
+     if not run_id and latest_run:
+         run_id = get_latest_run_id(echo, flow_name)
+         if run_id is None:
+             raise CommandException("A previous run id was not found. Specify --run-id.")
+
+     return flow_name, run_id, user
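
A quick sketch of how `parse_cli_options` resolves scope (values are hypothetical; `echo` is any message callable):

  # --my-runs: all unfinished tasks of the current user, across runs
  parse_cli_options("HelloFlow", None, None, True, print)   # -> ("HelloFlow", None, "<current user>")
  # no flags: fall back to the latest locally recorded run of this flow
  parse_cli_options("HelloFlow", None, None, False, print)  # -> ("HelloFlow", "<latest run id>", None)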
metaflow/plugins/kubernetes/kubernetes.py CHANGED
@@ -401,6 +401,9 @@ class Kubernetes(object):
              .label("app.kubernetes.io/name", "metaflow-task")
              .label("app.kubernetes.io/part-of", "metaflow")
          )
+         # We need this task-id set so that all the nodes are aware of the control
+         # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
+         jobset.environment_variable("MF_PARALLEL_CONTROL_TASK_ID", str(task_id))

          ## ----------- control/worker specific values START here -----------
          # We will now set the appropriate command for the control/worker job
metaflow/plugins/kubernetes/kubernetes_cli.py CHANGED
@@ -3,10 +3,12 @@ import sys
  import time
  import traceback

+ from metaflow.plugins.kubernetes.kube_utils import parse_cli_options
+ from metaflow.plugins.kubernetes.kubernetes_client import KubernetesClient
  import metaflow.tracing as tracing
  from metaflow import JSONTypeClass, util
  from metaflow._vendor import click
- from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
+ from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, MetaflowException
  from metaflow.metadata.util import sync_local_metadata_from_datastore
  from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
  from metaflow.mflog import TASK_LOG_SOURCE
@@ -305,3 +307,84 @@ def step(
          sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
      finally:
          _sync_metadata()
+
+
+ @kubernetes.command(help="List unfinished Kubernetes tasks of this flow.")
+ @click.option(
+     "--my-runs",
+     default=False,
+     is_flag=True,
+     help="List all my unfinished tasks.",
+ )
+ @click.option("--user", default=None, help="List unfinished tasks for the given user.")
+ @click.option(
+     "--run-id",
+     default=None,
+     help="List unfinished tasks corresponding to the run id.",
+ )
+ @click.pass_obj
+ def list(obj, run_id, user, my_runs):
+     flow_name, run_id, user = parse_cli_options(
+         obj.flow.name, run_id, user, my_runs, obj.echo
+     )
+     kube_client = KubernetesClient()
+     pods = kube_client.list(obj.flow.name, run_id, user)
+
+     def format_timestamp(timestamp=None):
+         if timestamp is None:
+             return "-"
+         return timestamp.strftime("%Y-%m-%d %H:%M:%S")
+
+     for pod in pods:
+         obj.echo(
+             "Run: *{run_id}* "
+             "Pod: *{pod_id}* "
+             "Started At: {startedAt} "
+             "Status: *{status}*".format(
+                 run_id=pod.metadata.annotations.get(
+                     "metaflow/run_id",
+                     pod.metadata.labels.get("workflows.argoproj.io/workflow"),
+                 ),
+                 pod_id=pod.metadata.name,
+                 startedAt=format_timestamp(pod.status.start_time),
+                 status=pod.status.phase,
+             )
+         )
+
+     if not pods:
+         obj.echo("No active Kubernetes pods found.")
+
+
+ @kubernetes.command(
+     help="Terminate unfinished Kubernetes tasks of this flow. Killed pods may result in newer attempts when using @retry."
+ )
+ @click.option(
+     "--my-runs",
+     default=False,
+     is_flag=True,
+     help="Kill all my unfinished tasks.",
+ )
+ @click.option(
+     "--user",
+     default=None,
+     help="Terminate unfinished tasks for the given user.",
+ )
+ @click.option(
+     "--run-id",
+     default=None,
+     help="Terminate unfinished tasks corresponding to the run id.",
+ )
+ @click.pass_obj
+ def kill(obj, run_id, user, my_runs):
+     flow_name, run_id, user = parse_cli_options(
+         obj.flow.name, run_id, user, my_runs, obj.echo
+     )
+
+     if run_id is not None and run_id.startswith("argo-") or user == "argo-workflows":
+         raise MetaflowException(
+             "Killing pods launched by Argo Workflows is not supported. "
+             "Use *argo-workflows terminate* instead."
+         )
+
+     kube_client = KubernetesClient()
+     kube_client.kill_pods(flow_name, run_id, user, obj.echo)
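
Assuming the flow file is named hello.py, the two new subcommands are invoked like any other Metaflow command group (the run id shown is illustrative):

  python hello.py kubernetes list --my-runs
  python hello.py kubernetes kill --run-id 181024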
metaflow/plugins/kubernetes/kubernetes_client.py CHANGED
@@ -1,8 +1,10 @@
+ from concurrent.futures import ThreadPoolExecutor
  import os
  import sys
  import time

  from metaflow.exception import MetaflowException
+ from metaflow.metaflow_config import KUBERNETES_NAMESPACE

  from .kubernetes_job import KubernetesJob, KubernetesJobSet

@@ -28,6 +30,7 @@ class KubernetesClient(object):
                  % sys.executable
              )
          self._refresh_client()
+         self._namespace = KUBERNETES_NAMESPACE

      def _refresh_client(self):
          from kubernetes import client, config
@@ -60,6 +63,100 @@ class KubernetesClient(object):

          return self._client

+     def _find_active_pods(self, flow_name, run_id=None, user=None):
+         def _request(_continue=None):
+             # handle paginated responses
+             return self._client.CoreV1Api().list_namespaced_pod(
+                 namespace=self._namespace,
+                 # limited selector support for K8S api. We want to cover multiple statuses: Running / Pending / Unknown
+                 field_selector="status.phase!=Succeeded,status.phase!=Failed",
+                 limit=1000,
+                 _continue=_continue,
+             )
+
+         results = _request()
+
+         if run_id is not None:
+             # handle argo prefixes in run_id
+             run_id = run_id[run_id.startswith("argo-") and len("argo-") :]
+
+         while results.metadata._continue or results.items:
+             for pod in results.items:
+                 match = (
+                     # arbitrary pods might have no annotations at all.
+                     pod.metadata.annotations
+                     and pod.metadata.labels
+                     and (
+                         run_id is None
+                         or (pod.metadata.annotations.get("metaflow/run_id") == run_id)
+                         # we want to also match pods launched by argo-workflows
+                         or (
+                             pod.metadata.labels.get("workflows.argoproj.io/workflow")
+                             == run_id
+                         )
+                     )
+                     and (
+                         user is None
+                         or pod.metadata.annotations.get("metaflow/user") == user
+                     )
+                     and (
+                         pod.metadata.annotations.get("metaflow/flow_name") == flow_name
+                     )
+                 )
+                 if match:
+                     yield pod
+             if not results.metadata._continue:
+                 break
+             results = _request(results.metadata._continue)
+
+     def list(self, flow_name, run_id, user):
+         results = self._find_active_pods(flow_name, run_id, user)
+
+         return list(results)
+
+     def kill_pods(self, flow_name, run_id, user, echo):
+         from kubernetes.stream import stream
+
+         api_instance = self._client.CoreV1Api()
+         job_api = self._client.BatchV1Api()
+         pods = self._find_active_pods(flow_name, run_id, user)
+
+         def _kill_pod(pod):
+             echo("Killing Kubernetes pod %s\n" % pod.metadata.name)
+             try:
+                 stream(
+                     api_instance.connect_get_namespaced_pod_exec,
+                     name=pod.metadata.name,
+                     namespace=pod.metadata.namespace,
+                     command=[
+                         "/bin/sh",
+                         "-c",
+                         "/sbin/killall5",
+                     ],
+                     stderr=True,
+                     stdin=False,
+                     stdout=True,
+                     tty=False,
+                 )
+             except Exception:
+                 # best effort kill for pod can fail.
+                 try:
+                     job_name = pod.metadata.labels.get("job-name", None)
+                     if job_name is None:
+                         raise Exception("Could not determine job name")
+
+                     job_api.patch_namespaced_job(
+                         name=job_name,
+                         namespace=pod.metadata.namespace,
+                         field_manager="metaflow",
+                         body={"spec": {"parallelism": 0}},
+                     )
+                 except Exception as e:
+                     echo("failed to kill pod %s - %s" % (pod.metadata.name, str(e)))
+
+         with ThreadPoolExecutor() as executor:
+             executor.map(_kill_pod, list(pods))
+
      def jobset(self, **kwargs):
          return KubernetesJobSet(self, **kwargs)

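A hedged usage sketch of the new client surface (flow and user names are made up; `list` and `kill_pods` are the methods added above):

  kube_client = KubernetesClient()
  # iterate over unfinished pods of HelloFlow owned by alice
  for pod in kube_client.list("HelloFlow", run_id=None, user="alice"):
      print(pod.metadata.name, pod.status.phase)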
metaflow/plugins/kubernetes/kubernetes_decorator.py CHANGED
@@ -70,6 +70,10 @@ class KubernetesDecorator(StepDecorator):
          Kubernetes secrets to use when launching pod in Kubernetes. These
          secrets are in addition to the ones defined in `METAFLOW_KUBERNETES_SECRETS`
          in Metaflow configuration.
+     node_selector: Union[Dict[str,str], str], optional, default None
+         Kubernetes node selector(s) to apply to the pod running the task.
+         Can be passed in as a comma separated string of values e.g. "kubernetes.io/os=linux,kubernetes.io/arch=amd64"
+         or as a dictionary {"kubernetes.io/os": "linux", "kubernetes.io/arch": "amd64"}
      namespace : str, default METAFLOW_KUBERNETES_NAMESPACE
          Kubernetes namespace to use when launching pod in Kubernetes.
      gpu : int, optional, default None
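
Both accepted forms from the new docstring, as they would appear on a step (a sketch; the labels are the docstring's own example values):

  @kubernetes(node_selector="kubernetes.io/os=linux,kubernetes.io/arch=amd64")
  @step
  def train(self):
      ...

  # equivalent dictionary form
  @kubernetes(node_selector={"kubernetes.io/os": "linux", "kubernetes.io/arch": "amd64"})
  @step
  def score(self):
      ...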
metaflow/plugins/parallel_decorator.py CHANGED
@@ -24,6 +24,8 @@ class ParallelDecorator(StepDecorator):
              The total number of tasks created by @parallel
          - node_index : int
              The index of the current task in all the @parallel tasks.
+         - control_task_id : Optional[str]
+             The task ID of the control task. Available to all tasks.

      is_parallel -> bool
          True if the current step is a @parallel step.
@@ -67,6 +69,7 @@ class ParallelDecorator(StepDecorator):
                      main_ip=os.environ.get("MF_PARALLEL_MAIN_IP", "127.0.0.1"),
                      num_nodes=int(os.environ.get("MF_PARALLEL_NUM_NODES", "1")),
                      node_index=int(os.environ.get("MF_PARALLEL_NODE_INDEX", "0")),
+                     control_task_id=os.environ.get("MF_PARALLEL_CONTROL_TASK_ID", None),
                  )
              ),
          )
@@ -177,6 +180,7 @@ def _local_multinode_control_task_step_func(
      num_parallel = foreach_iter.num_parallel
      os.environ["MF_PARALLEL_NUM_NODES"] = str(num_parallel)
      os.environ["MF_PARALLEL_MAIN_IP"] = "127.0.0.1"
+     os.environ["MF_PARALLEL_CONTROL_TASK_ID"] = str(current.task_id)

      run_id = current.run_id
      step_name = current.step_name
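
Since every launcher above now exports MF_PARALLEL_CONTROL_TASK_ID, any @parallel task can look up the control task through `current.parallel`; a minimal sketch of a step body:

  from metaflow import current

  # inside a @parallel step:
  info = current.parallel
  print(info.node_index, "/", info.num_nodes, "control task:", info.control_task_id)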
metaflow/plugins/pypi/bootstrap.py CHANGED
@@ -103,6 +103,7 @@ if __name__ == "__main__":
                  echo "@EXPLICIT" > "$tmpfile";
                  ls -d {conda_pkgs_dir}/*/* >> "$tmpfile";
                  export PATH=$PATH:$(pwd)/micromamba;
+                 export CONDA_PKGS_DIRS=$(pwd)/micromamba/pkgs;
                  micromamba create --yes --offline --no-deps --safety-checks=disabled --no-extra-safety-checks --prefix {prefix} --file "$tmpfile";
                  rm "$tmpfile"''',
          ]
@@ -123,6 +124,7 @@ if __name__ == "__main__":
          [
              f"""set -e;
              export PATH=$PATH:$(pwd)/micromamba;
+             export CONDA_PKGS_DIRS=$(pwd)/micromamba/pkgs;
              micromamba run --prefix {prefix} python -m pip --disable-pip-version-check install --root-user-action=ignore --no-compile {pypi_pkgs_dir}/*.whl --no-user"""
          ]
      )
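
For context, the added export points micromamba's package cache at a task-local directory instead of a shared default. The environment-creation step then runs roughly as follows (a sketch; {prefix} and $tmpfile are the templated values above):

  export PATH=$PATH:$(pwd)/micromamba;
  export CONDA_PKGS_DIRS=$(pwd)/micromamba/pkgs;   # task-local package cache
  micromamba create --yes --offline --no-deps --prefix {prefix} --file "$tmpfile";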
metaflow/plugins/pypi/conda_decorator.py CHANGED
@@ -12,7 +12,7 @@ from metaflow.metadata import MetaDatum
  from metaflow.metaflow_environment import InvalidEnvironmentException
  from metaflow.util import get_metaflow_root

- from ... import INFO_FILE
+ from ...info_file import INFO_FILE


  class CondaStepDecorator(StepDecorator):
@@ -353,6 +353,12 @@ class CondaFlowDecorator(FlowDecorator):
      def flow_init(
          self, flow, graph, environment, flow_datastore, metadata, logger, echo, options
      ):
+         # NOTE: Important for extensions implementing custom virtual environments.
+         # Without this steps will not have an implicit conda step decorator on them unless the environment adds one in its decospecs.
+         from metaflow import decorators
+
+         decorators._attach_decorators(flow, ["conda"])
+
          # @conda uses a conda environment to create a virtual environment.
          # The conda environment can be created through micromamba.
          _supported_virtual_envs = ["conda"]
metaflow/runner/click_api.py CHANGED
@@ -1,3 +1,4 @@
+ import os
  import sys

  if sys.version_info < (3, 7):
@@ -12,6 +13,7 @@ import importlib
  import inspect
  import itertools
  import uuid
+ import json
  from collections import OrderedDict
  from typing import Any, Callable, Dict, List, Optional
  from typing import OrderedDict as TOrderedDict
@@ -37,6 +39,9 @@ from metaflow.exception import MetaflowException
  from metaflow.includefile import FilePathClass
  from metaflow.parameters import JSONTypeClass, flow_context

+ # Define a recursive type alias for JSON
+ JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None]
+
  click_to_python_types = {
      StringParamType: str,
      IntParamType: int,
@@ -48,7 +53,7 @@ click_to_python_types = {
      Tuple: tuple,
      Choice: str,
      File: str,
-     JSONTypeClass: str,
+     JSONTypeClass: JSON,
      FilePathClass: str,
  }

@@ -82,6 +87,11 @@ def _method_sanity_check(
                  % (supplied_k, annotations[supplied_k], defaults[supplied_k])
              )

+         # because Click expects stringified JSON..
+         supplied_v = (
+             json.dumps(supplied_v) if annotations[supplied_k] == JSON else supplied_v
+         )
+
          if supplied_k in possible_arg_params:
              cli_name = possible_arg_params[supplied_k].opts[0].strip("-")
              method_params["args"][cli_name] = supplied_v
@@ -142,6 +152,8 @@ loaded_modules = {}


  def extract_flow_class_from_file(flow_file: str) -> FlowSpec:
+     if not os.path.exists(flow_file):
+         raise FileNotFoundError("Flow file not present at '%s'" % flow_file)
      # Check if the module has already been loaded
      if flow_file in loaded_modules:
          module = loaded_modules[flow_file]
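
The practical effect of the JSON alias: callers of the runner API can pass structured values for JSON-typed parameters, and `_method_sanity_check` serializes them for click. A hedged sketch (hello.py and its `config` parameter are hypothetical):

  from metaflow import Runner

  # hello.py is assumed to define Parameter("config", type=JSONType)
  with Runner("hello.py").run(config={"alpha": 0.5, "layers": [64, 64]}) as running:
      print(running.run)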