ob-metaflow 2.12.19.1__py2.py3-none-any.whl → 2.12.22.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ob-metaflow might be problematic.

Files changed (35)
  1. metaflow/__init__.py +11 -21
  2. metaflow/client/core.py +1 -1
  3. metaflow/cmd/main_cli.py +3 -2
  4. metaflow/extension_support/__init__.py +120 -29
  5. metaflow/flowspec.py +4 -0
  6. metaflow/info_file.py +25 -0
  7. metaflow/metaflow_config.py +0 -1
  8. metaflow/metaflow_current.py +3 -1
  9. metaflow/metaflow_environment.py +1 -7
  10. metaflow/metaflow_version.py +130 -64
  11. metaflow/package.py +2 -1
  12. metaflow/plugins/argo/argo_workflows.py +47 -8
  13. metaflow/plugins/argo/argo_workflows_cli.py +26 -0
  14. metaflow/plugins/aws/batch/batch_client.py +3 -0
  15. metaflow/plugins/kubernetes/kube_utils.py +25 -0
  16. metaflow/plugins/kubernetes/kubernetes.py +3 -0
  17. metaflow/plugins/kubernetes/kubernetes_cli.py +84 -1
  18. metaflow/plugins/kubernetes/kubernetes_client.py +97 -0
  19. metaflow/plugins/kubernetes/kubernetes_decorator.py +4 -0
  20. metaflow/plugins/parallel_decorator.py +4 -0
  21. metaflow/plugins/pypi/bootstrap.py +2 -0
  22. metaflow/plugins/pypi/conda_decorator.py +7 -1
  23. metaflow/runner/click_api.py +13 -1
  24. metaflow/runner/deployer.py +9 -2
  25. metaflow/runner/metaflow_runner.py +4 -2
  26. metaflow/runner/subprocess_manager.py +8 -3
  27. metaflow/runner/utils.py +19 -2
  28. metaflow/version.py +1 -1
  29. {ob_metaflow-2.12.19.1.dist-info → ob_metaflow-2.12.22.1.dist-info}/METADATA +2 -2
  30. {ob_metaflow-2.12.19.1.dist-info → ob_metaflow-2.12.22.1.dist-info}/RECORD +34 -33
  31. {ob_metaflow-2.12.19.1.dist-info → ob_metaflow-2.12.22.1.dist-info}/WHEEL +1 -1
  32. metaflow/plugins/argo/daemon.py +0 -59
  33. {ob_metaflow-2.12.19.1.dist-info → ob_metaflow-2.12.22.1.dist-info}/LICENSE +0 -0
  34. {ob_metaflow-2.12.19.1.dist-info → ob_metaflow-2.12.22.1.dist-info}/entry_points.txt +0 -0
  35. {ob_metaflow-2.12.19.1.dist-info → ob_metaflow-2.12.22.1.dist-info}/top_level.txt +0 -0
metaflow/metaflow_version.py CHANGED
@@ -7,11 +7,15 @@ See the documentation of get_version for more information
 
 # This file is adapted from https://github.com/aebrahim/python-git-version
 
-from subprocess import check_output, CalledProcessError
-from os import path, name, devnull, environ, listdir
-import json
+import subprocess
+from os import path, name, environ, listdir
 
-from metaflow import CURRENT_DIRECTORY, INFO_FILE
+from metaflow.extension_support import update_package_info
+from metaflow.info_file import CURRENT_DIRECTORY, read_info_file
+
+
+# True/False correspond to the value `public` in get_version
+_version_cache = {True: None, False: None}
 
 __all__ = ("get_version",)
 
@@ -57,87 +61,149 @@ if name == "nt":
     GIT_COMMAND = find_git_on_windows()
 
 
-def call_git_describe(abbrev=7):
+def call_git_describe(file_to_check, abbrev=7):
     """return the string output of git describe"""
     try:
-
-        # first, make sure we are actually in a Metaflow repo,
-        # not some other repo
-        with open(devnull, "w") as fnull:
-            arguments = [GIT_COMMAND, "rev-parse", "--show-toplevel"]
-            reponame = (
-                check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull)
-                .decode("ascii")
-                .strip()
-            )
-            if path.basename(reponame) != "metaflow":
-                return None
-
-        with open(devnull, "w") as fnull:
-            arguments = [GIT_COMMAND, "describe", "--tags", "--abbrev=%d" % abbrev]
-            return (
-                check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull)
-                .decode("ascii")
-                .strip()
-            )
-
-    except (OSError, CalledProcessError):
+        wd = path.dirname(file_to_check)
+        filename = path.basename(file_to_check)
+
+        # First check if the file is tracked in the GIT repository we are in
+        # We do this because in some setups and for some bizarre reason, python files
+        # are installed directly into a git repository (I am looking at you brew). We
+        # don't want to consider this a GIT install in that case.
+        args = [GIT_COMMAND, "ls-files", "--error-unmatch", filename]
+        git_return_code = subprocess.run(
+            args,
+            cwd=wd,
+            stderr=subprocess.DEVNULL,
+            stdout=subprocess.DEVNULL,
+            check=False,
+        ).returncode
+
+        if git_return_code != 0:
+            return None
+
+        args = [
+            GIT_COMMAND,
+            "describe",
+            "--tags",
+            "--dirty",
+            "--long",
+            "--abbrev=%d" % abbrev,
+        ]
+        return (
+            subprocess.check_output(args, cwd=wd, stderr=subprocess.DEVNULL)
+            .decode("ascii")
+            .strip()
+        )
+
+    except (OSError, subprocess.CalledProcessError):
         return None
 
 
-def format_git_describe(git_str, pep440=False):
+def format_git_describe(git_str, public=False):
     """format the result of calling 'git describe' as a python version"""
     if git_str is None:
         return None
-    if "-" not in git_str:  # currently at a tag
-        return git_str
+    splits = git_str.split("-")
+    if len(splits) == 4:
+        # Formatted as <tag>-<post>-<hash>-dirty
+        tag, post, h = splits[:3]
+        dirty = "-" + splits[3]
     else:
-        # formatted as version-N-githash
-        # want to convert to version.postN-githash
-        git_str = git_str.replace("-", ".post", 1)
-        if pep440:  # does not allow git hash afterwards
-            return git_str.split("-")[0]
-        else:
-            return git_str.replace("-g", "+git")
+        # Formatted as <tag>-<post>-<hash>
+        tag, post, h = splits
+        dirty = ""
+    if post == "0":
+        if public:
+            return tag
+        return tag + dirty
+
+    if public:
+        return "%s.post%s" % (tag, post)
+
+    return "%s.post%s-git%s%s" % (tag, post, h[1:], dirty)
 
 
 def read_info_version():
     """Read version information from INFO file"""
-    try:
-        with open(INFO_FILE, "r") as contents:
-            return json.load(contents).get("metaflow_version")
-    except IOError:
-        return None
+    info_file = read_info_file()
+    if info_file:
+        return info_file.get("metaflow_version")
+    return None
 
 
-def get_version(pep440=False):
+def get_version(public=False):
     """Tracks the version number.
 
-    pep440: bool
-        When True, this function returns a version string suitable for
-        a release as defined by PEP 440. When False, the githash (if
-        available) will be appended to the version string.
+    public: bool
+        When True, this function returns a *public* version specification which
+        doesn't include any local information (dirtiness or hash). See
+        https://packaging.python.org/en/latest/specifications/version-specifiers/#version-scheme
 
-    If the script is located within an active git repository,
-    git-describe is used to get the version information.
+    We first check the INFO file to see if we recorded a version of Metaflow. If there
+    is none, we check if we are in a GIT repository and if so, form the version
+    from that.
 
-    Otherwise, the version logged by package installer is returned.
-
-    If even that information isn't available (likely when executing on a
-    remote cloud instance), the version information is returned from INFO file
-    in the current directory.
+    Otherwise, we return the version of Metaflow that was installed.
 
     """
 
-    version = format_git_describe(call_git_describe(), pep440=pep440)
-    version_addl = None
-    if version is None:  # not a git repository
-        import metaflow
-
+    global _version_cache
+
+    # To get the version we do the following:
+    #  - Check if we have a cached version. If so, return that
+    #  - Then check if we have an INFO file present. If so, use that as it is
+    #    the most reliable way to get the version. In particular, when running remotely,
+    #    metaflow is installed in a directory and if any extension is using distutils to
+    #    determine its version, this would return None and querying the version directly
+    #    from the extension would fail to produce the correct result
+    #  - Then if we are in the GIT repository and if so, use the git describe
+    #  - If we don't have an INFO file, we look at the version information that is
+    #    populated by metaflow and the extensions.
+
+    if _version_cache[public] is not None:
+        return _version_cache[public]
+
+    version = (
+        read_info_version()
+    )  # Version info is cached in INFO file; includes extension info
+
+    if version:
+        _version_cache[public] = version
+        return version
+
+    # Get the version for Metaflow, favor the GIT version
+    import metaflow
+
+    version = format_git_describe(
+        call_git_describe(file_to_check=metaflow.__file__), public=public
+    )
+    if version is None:
        version = metaflow.__version__
-        version_addl = metaflow.__version_addl__
-    if version is None:  # not a proper python package
-        return read_info_version()
-    if version_addl:
-        return "+".join([version, version_addl])
+
+    # Look for extensions and compute their versions. Properly formed extensions have
+    # a toplevel file which will contain a __mf_extensions__ value and a __version__
+    # value. We already saved the properly formed modules when loading metaflow in
+    # __ext_tl_modules__.
+    ext_versions = []
+    for pkg_name, extension_module in metaflow.__ext_tl_modules__:
+        ext_name = getattr(extension_module, "__mf_extensions__", "<unk>")
+        ext_version = format_git_describe(
+            call_git_describe(file_to_check=extension_module.__file__), public=public
+        )
+        if ext_version is None:
+            ext_version = getattr(extension_module, "__version__", "<unk>")
+        # Update the package information about reported version for the extension
+        update_package_info(
+            package_name=pkg_name,
+            extension_name=ext_name,
+            package_version=ext_version,
+        )
+        ext_versions.append("%s(%s)" % (ext_name, ext_version))
+
+    # We now have all the information about extensions so we can form the final string
+    if ext_versions:
+        version = version + "+" + ";".join(ext_versions)
+    _version_cache[public] = version
     return version
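
Note: a minimal worked example of the formatting logic above, assuming a hypothetical 2.12.22 tag; the inputs follow the <tag>-<post>-g<hash>[-dirty] shape produced by `git describe --tags --dirty --long`, and the expected outputs are read directly off the format_git_describe hunk (this note is not part of the diff).

    from metaflow.metaflow_version import format_git_describe

    format_git_describe("2.12.22-0-gabc1234")                     # "2.12.22" (exactly at the tag)
    format_git_describe("2.12.22-3-gabc1234")                     # "2.12.22.post3-gitabc1234"
    format_git_describe("2.12.22-3-gabc1234-dirty")               # "2.12.22.post3-gitabc1234-dirty"
    format_git_describe("2.12.22-3-gabc1234-dirty", public=True)  # "2.12.22.post3" (public, PEP 440-style)
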
metaflow/package.py CHANGED
@@ -10,7 +10,8 @@ from .extension_support import EXT_PKG, package_mfext_all
 from .metaflow_config import DEFAULT_PACKAGE_SUFFIXES
 from .exception import MetaflowException
 from .util import to_unicode
-from . import R, INFO_FILE
+from . import R
+from .info_file import INFO_FILE
 
 DEFAULT_SUFFIXES_LIST = DEFAULT_PACKAGE_SUFFIXES.split(",")
 METAFLOW_SUFFIXES_LIST = [".py", ".html", ".css", ".js"]
metaflow/plugins/argo/argo_workflows.py CHANGED
@@ -456,11 +456,17 @@ class ArgoWorkflows(object):
                 )
             seen.add(norm)
 
-            if param.kwargs.get("type") == JSONType or isinstance(
-                param.kwargs.get("type"), FilePathClass
-            ):
-                # Special-case this to avoid touching core
+            extra_attrs = {}
+            if param.kwargs.get("type") == JSONType:
+                param_type = str(param.kwargs.get("type").name)
+            elif isinstance(param.kwargs.get("type"), FilePathClass):
                 param_type = str(param.kwargs.get("type").name)
+                extra_attrs["is_text"] = getattr(
+                    param.kwargs.get("type"), "_is_text", True
+                )
+                extra_attrs["encoding"] = getattr(
+                    param.kwargs.get("type"), "_encoding", "utf-8"
+                )
             else:
                 param_type = str(param.kwargs.get("type").__name__)
 
@@ -488,6 +494,7 @@ class ArgoWorkflows(object):
                 type=param_type,
                 description=param.kwargs.get("help"),
                 is_required=is_required,
+                **extra_attrs
             )
         return parameters
 
@@ -1906,6 +1913,12 @@ class ArgoWorkflows(object):
             jobset.environment_variable(
                 "MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}"
             )
+            # We need this task-id set so that all the nodes are aware of the control
+            # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
+            jobset.environment_variable(
+                "MF_PARALLEL_CONTROL_TASK_ID",
+                "control-{{inputs.parameters.task-id-entropy}}-0",
+            )
             # for k, v in .items():
             jobset.environment_variables_from_selectors(
                 {
@@ -2518,10 +2531,29 @@ class ArgoWorkflows(object):
         # Use all the affordances available to _parameters task
         executable = self.environment.executable("_parameters")
         run_id = "argo-{{workflow.name}}"
-        entrypoint = [executable, "-m metaflow.plugins.argo.daemon"]
-        heartbeat_cmds = "{entrypoint} --flow_name {flow_name} --run_id {run_id} {tags} heartbeat".format(
+        script_name = os.path.basename(sys.argv[0])
+        entrypoint = [executable, script_name]
+        # FlowDecorators can define their own top-level options. These might affect run level information
+        # so it is important to pass these to the heartbeat process as well, as it might be the first task to register a run.
+        top_opts_dict = {}
+        for deco in flow_decorators(self.flow):
+            top_opts_dict.update(deco.get_top_level_options())
+
+        top_level = list(dict_to_cli_options(top_opts_dict)) + [
+            "--quiet",
+            "--metadata=%s" % self.metadata.TYPE,
+            "--environment=%s" % self.environment.TYPE,
+            "--datastore=%s" % self.flow_datastore.TYPE,
+            "--datastore-root=%s" % self.flow_datastore.datastore_root,
+            "--event-logger=%s" % self.event_logger.TYPE,
+            "--monitor=%s" % self.monitor.TYPE,
+            "--no-pylint",
+            "--with=argo_workflows_internal:auto-emit-argo-events=%i"
+            % self.auto_emit_argo_events,
+        ]
+        heartbeat_cmds = "{entrypoint} {top_level} argo-workflows heartbeat --run_id {run_id} {tags}".format(
             entrypoint=" ".join(entrypoint),
-            flow_name=self.flow.name,
+            top_level=" ".join(top_level) if top_level else "",
            run_id=run_id,
             tags=" ".join(["--tag %s" % t for t in self.tags]) if self.tags else "",
         )
@@ -2561,8 +2593,8 @@ class ArgoWorkflows(object):
         cmd_str = " && ".join([init_cmds, heartbeat_cmds])
         cmds = shlex.split('bash -c "%s"' % cmd_str)
 
-        # TODO: Check that this is the minimal env.
         # Env required for sending heartbeats to the metadata service, nothing extra.
+        # prod token / runtime info is required to correctly register flow branches
        env = {
             # These values are needed by Metaflow to set it's internal
             # state appropriately.
@@ -2572,9 +2604,16 @@ class ArgoWorkflows(object):
             "METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
             "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
             "METAFLOW_USER": "argo-workflows",
+            "METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3,
+            "METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
             "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
             "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
+            "METAFLOW_CARD_S3ROOT": CARD_S3ROOT,
+            "METAFLOW_KUBERNETES_WORKLOAD": 1,
+            "METAFLOW_KUBERNETES_FETCH_EC2_METADATA": KUBERNETES_FETCH_EC2_METADATA,
+            "METAFLOW_RUNTIME_ENVIRONMENT": "kubernetes",
             "METAFLOW_OWNER": self.username,
+            "METAFLOW_PRODUCTION_TOKEN": self.production_token,  # Used in identity resolving. This affects system tags.
         }
         # support Metaflow sandboxes
         env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
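
Note: a hedged illustration of the daemon command that the template above assembles; the executable, script name, and option values below are hypothetical stand-ins for the deploy-time fields (the top-level options are abbreviated), so this is a sketch rather than the exact emitted command.

    entrypoint = "python helloflow.py"  # executable + os.path.basename(sys.argv[0])
    top_level = "--quiet --metadata=service --environment=local --datastore=s3"  # abbreviated
    run_id = "argo-{{workflow.name}}"
    tags = "--tag example"

    heartbeat_cmds = "{entrypoint} {top_level} argo-workflows heartbeat --run_id {run_id} {tags}".format(
        entrypoint=entrypoint, top_level=top_level, run_id=run_id, tags=tags
    )
    # -> "python helloflow.py --quiet ... argo-workflows heartbeat --run_id argo-{{workflow.name}} --tag example"
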
metaflow/plugins/argo/argo_workflows_cli.py CHANGED
@@ -4,6 +4,7 @@ import platform
 import re
 import sys
 from hashlib import sha1
+from time import sleep
 
 from metaflow import JSONType, Run, current, decorators, parameters
 from metaflow._vendor import click
@@ -959,6 +960,31 @@ def list_workflow_templates(obj, all=None):
         obj.echo_always(template_name)
 
 
+# Internal CLI command to run a heartbeat daemon in an Argo Workflows Daemon container.
+@argo_workflows.command(hidden=True, help="start heartbeat process for a run")
+@click.option("--run_id", required=True)
+@click.option(
+    "--tag",
+    "tags",
+    multiple=True,
+    default=None,
+    help="Annotate all objects produced by Argo Workflows runs "
+    "with the given tag. You can specify this option multiple "
+    "times to attach multiple tags.",
+)
+@click.pass_obj
+def heartbeat(obj, run_id, tags=None):
+    # Try to register a run in case the start task has not taken care of it yet.
+    obj.metadata.register_run_id(run_id, tags)
+    # Start run heartbeat
+    obj.metadata.start_run_heartbeat(obj.flow.name, run_id)
+    # Keepalive loop
+    while True:
+        # Do not pollute daemon logs with anything unnecessary,
+        # as they might be extremely long running.
+        sleep(10)
+
+
 def validate_run_id(
     workflow_name, token_prefix, authorize, run_id, instructions_fn=None
 ):
metaflow/plugins/aws/batch/batch_client.py CHANGED
@@ -89,6 +89,9 @@ class BatchJob(object):
         # Multinode
         if getattr(self, "num_parallel", 0) >= 1:
             num_nodes = self.num_parallel
+            # We need this task-id set so that all the nodes are aware of the control
+            # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
+            self.environment_variable("MF_PARALLEL_CONTROL_TASK_ID", self._task_id)
             main_task_override = copy.deepcopy(self.payload["containerOverrides"])
 
             # main
metaflow/plugins/kubernetes/kube_utils.py ADDED
@@ -0,0 +1,25 @@
+from metaflow.exception import CommandException
+from metaflow.util import get_username, get_latest_run_id
+
+
+def parse_cli_options(flow_name, run_id, user, my_runs, echo):
+    if user and my_runs:
+        raise CommandException("--user and --my-runs are mutually exclusive.")
+
+    if run_id and my_runs:
+        raise CommandException("--run_id and --my-runs are mutually exclusive.")
+
+    if my_runs:
+        user = get_username()
+
+    latest_run = True
+
+    if user and not run_id:
+        latest_run = False
+
+    if not run_id and latest_run:
+        run_id = get_latest_run_id(echo, flow_name)
+        if run_id is None:
+            raise CommandException("A previous run id was not found. Specify --run-id.")
+
+    return flow_name, run_id, user
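
Note: a short usage sketch of the helper added above; the flow name is hypothetical and a plain print stands in for the CLI's echo. With --my-runs the user is resolved via get_username() and run_id stays None, so every unfinished task of that user is matched.

    from metaflow.plugins.kubernetes.kube_utils import parse_cli_options

    def echo(msg, **kwargs):
        print(msg)

    # roughly the equivalent of `... kubernetes list --my-runs`
    flow_name, run_id, user = parse_cli_options(
        "HelloFlow", run_id=None, user=None, my_runs=True, echo=echo
    )
    # flow_name == "HelloFlow", run_id is None, user == the current username
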
metaflow/plugins/kubernetes/kubernetes.py CHANGED
@@ -401,6 +401,9 @@ class Kubernetes(object):
             .label("app.kubernetes.io/name", "metaflow-task")
             .label("app.kubernetes.io/part-of", "metaflow")
         )
+        # We need this task-id set so that all the nodes are aware of the control
+        # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
+        jobset.environment_variable("MF_PARALLEL_CONTROL_TASK_ID", str(task_id))
 
         ## ----------- control/worker specific values START here -----------
         # We will now set the appropriate command for the control/worker job
metaflow/plugins/kubernetes/kubernetes_cli.py CHANGED
@@ -3,10 +3,12 @@ import sys
 import time
 import traceback
 
+from metaflow.plugins.kubernetes.kube_utils import parse_cli_options
+from metaflow.plugins.kubernetes.kubernetes_client import KubernetesClient
 import metaflow.tracing as tracing
 from metaflow import JSONTypeClass, util
 from metaflow._vendor import click
-from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
+from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, MetaflowException
 from metaflow.metadata.util import sync_local_metadata_from_datastore
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
 from metaflow.mflog import TASK_LOG_SOURCE
@@ -305,3 +307,84 @@ def step(
         sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
     finally:
         _sync_metadata()
+
+
+@kubernetes.command(help="List unfinished Kubernetes tasks of this flow.")
+@click.option(
+    "--my-runs",
+    default=False,
+    is_flag=True,
+    help="List all my unfinished tasks.",
+)
+@click.option("--user", default=None, help="List unfinished tasks for the given user.")
+@click.option(
+    "--run-id",
+    default=None,
+    help="List unfinished tasks corresponding to the run id.",
+)
+@click.pass_obj
+def list(obj, run_id, user, my_runs):
+    flow_name, run_id, user = parse_cli_options(
+        obj.flow.name, run_id, user, my_runs, obj.echo
+    )
+    kube_client = KubernetesClient()
+    pods = kube_client.list(obj.flow.name, run_id, user)
+
+    def format_timestamp(timestamp=None):
+        if timestamp is None:
+            return "-"
+        return timestamp.strftime("%Y-%m-%d %H:%M:%S")
+
+    for pod in pods:
+        obj.echo(
+            "Run: *{run_id}* "
+            "Pod: *{pod_id}* "
+            "Started At: {startedAt} "
+            "Status: *{status}*".format(
+                run_id=pod.metadata.annotations.get(
+                    "metaflow/run_id",
+                    pod.metadata.labels.get("workflows.argoproj.io/workflow"),
+                ),
+                pod_id=pod.metadata.name,
+                startedAt=format_timestamp(pod.status.start_time),
+                status=pod.status.phase,
+            )
+        )
+
+    if not pods:
+        obj.echo("No active Kubernetes pods found.")
+
+
+@kubernetes.command(
+    help="Terminate unfinished Kubernetes tasks of this flow. Killed pods may result in newer attempts when using @retry."
+)
+@click.option(
+    "--my-runs",
+    default=False,
+    is_flag=True,
+    help="Kill all my unfinished tasks.",
+)
+@click.option(
+    "--user",
+    default=None,
+    help="Terminate unfinished tasks for the given user.",
+)
+@click.option(
+    "--run-id",
+    default=None,
+    help="Terminate unfinished tasks corresponding to the run id.",
+)
+@click.pass_obj
+def kill(obj, run_id, user, my_runs):
+    flow_name, run_id, user = parse_cli_options(
+        obj.flow.name, run_id, user, my_runs, obj.echo
+    )
+
+    if run_id is not None and run_id.startswith("argo-") or user == "argo-workflows":
+        raise MetaflowException(
+            "Killing pods launched by Argo Workflows is not supported. "
+            "Use *argo-workflows terminate* instead."
+        )
+
+    kube_client = KubernetesClient()
+    kube_client.kill_pods(flow_name, run_id, user, obj.echo)
metaflow/plugins/kubernetes/kubernetes_client.py CHANGED
@@ -1,8 +1,10 @@
+from concurrent.futures import ThreadPoolExecutor
 import os
 import sys
 import time
 
 from metaflow.exception import MetaflowException
+from metaflow.metaflow_config import KUBERNETES_NAMESPACE
 
 from .kubernetes_job import KubernetesJob, KubernetesJobSet
 
@@ -28,6 +30,7 @@ class KubernetesClient(object):
                 % sys.executable
             )
         self._refresh_client()
+        self._namespace = KUBERNETES_NAMESPACE
 
     def _refresh_client(self):
         from kubernetes import client, config
@@ -60,6 +63,100 @@ class KubernetesClient(object):
 
         return self._client
 
+    def _find_active_pods(self, flow_name, run_id=None, user=None):
+        def _request(_continue=None):
+            # handle paginated responses
+            return self._client.CoreV1Api().list_namespaced_pod(
+                namespace=self._namespace,
+                # limited selector support for K8S api. We want to cover multiple statuses: Running / Pending / Unknown
+                field_selector="status.phase!=Succeeded,status.phase!=Failed",
+                limit=1000,
+                _continue=_continue,
+            )
+
+        results = _request()
+
+        if run_id is not None:
+            # handle argo prefixes in run_id
+            run_id = run_id[run_id.startswith("argo-") and len("argo-") :]
+
+        while results.metadata._continue or results.items:
+            for pod in results.items:
+                match = (
+                    # arbitrary pods might have no annotations at all.
+                    pod.metadata.annotations
+                    and pod.metadata.labels
+                    and (
+                        run_id is None
+                        or (pod.metadata.annotations.get("metaflow/run_id") == run_id)
+                        # we want to also match pods launched by argo-workflows
+                        or (
+                            pod.metadata.labels.get("workflows.argoproj.io/workflow")
+                            == run_id
+                        )
+                    )
+                    and (
+                        user is None
+                        or pod.metadata.annotations.get("metaflow/user") == user
+                    )
+                    and (
+                        pod.metadata.annotations.get("metaflow/flow_name") == flow_name
+                    )
+                )
+                if match:
+                    yield pod
+            if not results.metadata._continue:
+                break
+            results = _request(results.metadata._continue)
+
+    def list(self, flow_name, run_id, user):
+        results = self._find_active_pods(flow_name, run_id, user)
+
+        return list(results)
+
+    def kill_pods(self, flow_name, run_id, user, echo):
+        from kubernetes.stream import stream
+
+        api_instance = self._client.CoreV1Api()
+        job_api = self._client.BatchV1Api()
+        pods = self._find_active_pods(flow_name, run_id, user)
+
+        def _kill_pod(pod):
+            echo("Killing Kubernetes pod %s\n" % pod.metadata.name)
+            try:
+                stream(
+                    api_instance.connect_get_namespaced_pod_exec,
+                    name=pod.metadata.name,
+                    namespace=pod.metadata.namespace,
+                    command=[
+                        "/bin/sh",
+                        "-c",
+                        "/sbin/killall5",
+                    ],
+                    stderr=True,
+                    stdin=False,
+                    stdout=True,
+                    tty=False,
+                )
+            except Exception:
+                # best effort kill for pod can fail.
+                try:
+                    job_name = pod.metadata.labels.get("job-name", None)
+                    if job_name is None:
+                        raise Exception("Could not determine job name")
+
+                    job_api.patch_namespaced_job(
+                        name=job_name,
+                        namespace=pod.metadata.namespace,
+                        field_manager="metaflow",
+                        body={"spec": {"parallelism": 0}},
+                    )
+                except Exception as e:
+                    echo("failed to kill pod %s - %s" % (pod.metadata.name, str(e)))
+
+        with ThreadPoolExecutor() as executor:
+            executor.map(_kill_pod, list(pods))
+
     def jobset(self, **kwargs):
         return KubernetesJobSet(self, **kwargs)
metaflow/plugins/kubernetes/kubernetes_decorator.py CHANGED
@@ -72,6 +72,10 @@ class KubernetesDecorator(StepDecorator):
         Kubernetes secrets to use when launching pod in Kubernetes. These
         secrets are in addition to the ones defined in `METAFLOW_KUBERNETES_SECRETS`
         in Metaflow configuration.
+    node_selector: Union[Dict[str,str], str], optional, default None
+        Kubernetes node selector(s) to apply to the pod running the task.
+        Can be passed in as a comma separated string of values e.g. "kubernetes.io/os=linux,kubernetes.io/arch=amd64"
+        or as a dictionary {"kubernetes.io/os": "linux", "kubernetes.io/arch": "amd64"}
     namespace : str, default METAFLOW_KUBERNETES_NAMESPACE
         Kubernetes namespace to use when launching pod in Kubernetes.
     gpu : int, optional, default None
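
Note: a hedged usage sketch of the two documented node_selector forms; the flow name and selector values are illustrative and not part of the diff.

    from metaflow import FlowSpec, kubernetes, step

    class NodeSelectorFlow(FlowSpec):
        # comma-separated string form
        @kubernetes(node_selector="kubernetes.io/os=linux,kubernetes.io/arch=amd64")
        @step
        def start(self):
            self.next(self.end)

        # dictionary form
        @kubernetes(node_selector={"kubernetes.io/os": "linux", "kubernetes.io/arch": "amd64"})
        @step
        def end(self):
            pass

    if __name__ == "__main__":
        NodeSelectorFlow()
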