ob-metaflow 2.12.32.1__py2.py3-none-any.whl → 2.12.35.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow might be problematic. See the package's registry listing for more details.

metaflow/flowspec.py CHANGED
@@ -38,6 +38,7 @@ INTERNAL_ARTIFACTS_SET = set(
38
38
  "_unbounded_foreach",
39
39
  "_control_mapper_tasks",
40
40
  "_control_task_is_mapper_zero",
41
+ "_parallel_ubf_iter",
41
42
  ]
42
43
  )
43
44
 
@@ -378,6 +378,8 @@ KUBERNETES_PORT = from_conf("KUBERNETES_PORT", None)
378
378
  KUBERNETES_CPU = from_conf("KUBERNETES_CPU", None)
379
379
  KUBERNETES_MEMORY = from_conf("KUBERNETES_MEMORY", None)
380
380
  KUBERNETES_DISK = from_conf("KUBERNETES_DISK", None)
381
+ # Default kubernetes QoS class
382
+ KUBERNETES_QOS = from_conf("KUBERNETES_QOS", "burstable")
381
383
 
382
384
  ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "")
383
385
  ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "")
@@ -6,7 +6,18 @@ from tempfile import NamedTemporaryFile
6
6
  import time
7
7
  import metaflow.tracing as tracing
8
8
 
9
- from typing import Any, Callable, Iterable, Iterator, List, Optional
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ Iterable,
13
+ Iterator,
14
+ List,
15
+ Optional,
16
+ NoReturn,
17
+ Tuple,
18
+ TypeVar,
19
+ Union,
20
+ )
10
21
 
11
22
  try:
12
23
  # Python 2
@@ -30,7 +41,13 @@ class MulticoreException(Exception):
30
41
  pass
31
42
 
32
43
 
33
- def _spawn(func, arg, dir):
44
+ _A = TypeVar("_A")
45
+ _R = TypeVar("_R")
46
+
47
+
48
+ def _spawn(
49
+ func: Callable[[_A], _R], arg: _A, dir: Optional[str]
50
+ ) -> Union[Tuple[int, str], NoReturn]:
34
51
  with NamedTemporaryFile(prefix="parallel_map_", dir=dir, delete=False) as tmpfile:
35
52
  output_file = tmpfile.name
36
53
 
@@ -63,11 +80,11 @@ def _spawn(func, arg, dir):
63
80
 
64
81
 
65
82
  def parallel_imap_unordered(
66
- func: Callable[[Any], Any],
67
- iterable: Iterable[Any],
83
+ func: Callable[[_A], _R],
84
+ iterable: Iterable[_A],
68
85
  max_parallel: Optional[int] = None,
69
86
  dir: Optional[str] = None,
70
- ) -> Iterator[Any]:
87
+ ) -> Iterator[_R]:
71
88
  """
72
89
  Parallelizes execution of a function using multiprocessing. The result
73
90
  order is not guaranteed.
@@ -79,9 +96,9 @@ def parallel_imap_unordered(
79
96
  iterable : Iterable[Any]
80
97
  Iterable over arguments to pass to fun
81
98
  max_parallel int, optional, default None
82
- Maximum parallelism. If not specified, uses the number of CPUs
99
+ Maximum parallelism. If not specified, it uses the number of CPUs
83
100
  dir : str, optional, default None
84
- If specified, directory where temporary files are created
101
+ If specified, it's the directory where temporary files are created
85
102
 
86
103
  Yields
87
104
  ------
@@ -121,14 +138,14 @@ def parallel_imap_unordered(
121
138
 
122
139
 
123
140
  def parallel_map(
124
- func: Callable[[Any], Any],
125
- iterable: Iterable[Any],
141
+ func: Callable[[_A], _R],
142
+ iterable: Iterable[_A],
126
143
  max_parallel: Optional[int] = None,
127
144
  dir: Optional[str] = None,
128
- ) -> List[Any]:
145
+ ) -> List[_R]:
129
146
  """
130
147
  Parallelizes execution of a function using multiprocessing. The result
131
- order is that of the arguments in `iterable`
148
+ order is that of the arguments in `iterable`.
132
149
 
133
150
  Parameters
134
151
  ----------
@@ -137,9 +154,9 @@ def parallel_map(
137
154
  iterable : Iterable[Any]
138
155
  Iterable over arguments to pass to fun
139
156
  max_parallel int, optional, default None
140
- Maximum parallelism. If not specified, uses the number of CPUs
157
+ Maximum parallelism. If not specified, it uses the number of CPUs
141
158
  dir : str, optional, default None
142
- If specified, directory where temporary files are created
159
+ If specified, it's the directory where temporary files are created
143
160
 
144
161
  Returns
145
162
  -------
@@ -155,4 +172,4 @@ def parallel_map(
155
172
  res = parallel_imap_unordered(
156
173
  wrapper, enumerate(iterable), max_parallel=max_parallel, dir=dir
157
174
  )
158
- return [r for idx, r in sorted(res)]
175
+ return [r for _, r in sorted(res)]
@@ -46,6 +46,7 @@ from metaflow.parameters import (
46
46
  # TODO: Move chevron to _vendor
47
47
  from metaflow.plugins.cards.card_modules import chevron
48
48
  from metaflow.plugins.kubernetes.kubernetes import Kubernetes
49
+ from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits
49
50
  from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
50
51
  from metaflow.util import compress_list, dict_to_cli_options, get_username
51
52
 
@@ -428,25 +429,25 @@ class Airflow(object):
428
429
  if k8s_deco.attributes["namespace"] is not None
429
430
  else "default"
430
431
  )
431
-
432
+ qos_requests, qos_limits = qos_requests_and_limits(
433
+ k8s_deco.attributes["qos"],
434
+ k8s_deco.attributes["cpu"],
435
+ k8s_deco.attributes["memory"],
436
+ k8s_deco.attributes["disk"],
437
+ )
432
438
  resources = dict(
433
- requests={
434
- "cpu": k8s_deco.attributes["cpu"],
435
- "memory": "%sM" % str(k8s_deco.attributes["memory"]),
436
- "ephemeral-storage": str(k8s_deco.attributes["disk"]),
437
- }
439
+ requests=qos_requests,
440
+ limits={
441
+ **qos_limits,
442
+ **{
443
+ "%s.com/gpu".lower()
444
+ % k8s_deco.attributes["gpu_vendor"]: str(k8s_deco.attributes["gpu"])
445
+ for k in [0]
446
+ # Don't set GPU limits if gpu isn't specified.
447
+ if k8s_deco.attributes["gpu"] is not None
448
+ },
449
+ },
438
450
  )
439
- if k8s_deco.attributes["gpu"] is not None:
440
- resources.update(
441
- dict(
442
- limits={
443
- "%s.com/gpu".lower()
444
- % k8s_deco.attributes["gpu_vendor"]: str(
445
- k8s_deco.attributes["gpu"]
446
- )
447
- }
448
- )
449
- )
450
451
 
451
452
  annotations = {
452
453
  "metaflow/production_token": self.production_token,
@@ -54,6 +54,7 @@ from metaflow.metaflow_config import (
54
54
  from metaflow.metaflow_config_funcs import config_values, init_config
55
55
  from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
56
56
  from metaflow.parameters import deploy_time_eval
57
+ from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits
57
58
  from metaflow.plugins.kubernetes.kubernetes import (
58
59
  parse_kube_keyvalue_list,
59
60
  validate_kube_labels,
@@ -1858,6 +1859,13 @@ class ArgoWorkflows(object):
1858
1859
  if tmpfs_enabled and tmpfs_tempdir:
1859
1860
  env["METAFLOW_TEMPDIR"] = tmpfs_path
1860
1861
 
1862
+ qos_requests, qos_limits = qos_requests_and_limits(
1863
+ resources["qos"],
1864
+ resources["cpu"],
1865
+ resources["memory"],
1866
+ resources["disk"],
1867
+ )
1868
+
1861
1869
  # Create a ContainerTemplate for this node. Ideally, we would have
1862
1870
  # liked to inline this ContainerTemplate and avoid scanning the workflow
1863
1871
  # twice, but due to issues with variable substitution, we will have to
@@ -1921,6 +1929,7 @@ class ArgoWorkflows(object):
1921
1929
  persistent_volume_claims=resources["persistent_volume_claims"],
1922
1930
  shared_memory=shared_memory,
1923
1931
  port=port,
1932
+ qos=resources["qos"],
1924
1933
  )
1925
1934
 
1926
1935
  for k, v in env.items():
@@ -2113,17 +2122,17 @@ class ArgoWorkflows(object):
2113
2122
  image=resources["image"],
2114
2123
  image_pull_policy=resources["image_pull_policy"],
2115
2124
  resources=kubernetes_sdk.V1ResourceRequirements(
2116
- requests={
2117
- "cpu": str(resources["cpu"]),
2118
- "memory": "%sM" % str(resources["memory"]),
2119
- "ephemeral-storage": "%sM"
2120
- % str(resources["disk"]),
2121
- },
2125
+ requests=qos_requests,
2122
2126
  limits={
2123
- "%s.com/gpu".lower()
2124
- % resources["gpu_vendor"]: str(resources["gpu"])
2125
- for k in [0]
2126
- if resources["gpu"] is not None
2127
+ **qos_limits,
2128
+ **{
2129
+ "%s.com/gpu".lower()
2130
+ % resources["gpu_vendor"]: str(
2131
+ resources["gpu"]
2132
+ )
2133
+ for k in [0]
2134
+ if resources["gpu"] is not None
2135
+ },
2127
2136
  },
2128
2137
  ),
2129
2138
  # Configure secrets
@@ -2360,7 +2369,7 @@ class ArgoWorkflows(object):
2360
2369
  "memory": "500Mi",
2361
2370
  },
2362
2371
  ),
2363
- )
2372
+ ).to_dict()
2364
2373
  )
2365
2374
  ),
2366
2375
  Template("capture-error-hook-fn-preflight").steps(
@@ -2719,7 +2728,7 @@ class ArgoWorkflows(object):
2719
2728
  },
2720
2729
  ),
2721
2730
  )
2722
- )
2731
+ ).to_dict()
2723
2732
  )
2724
2733
  )
2725
2734
 
@@ -2889,7 +2898,7 @@ class ArgoWorkflows(object):
2889
2898
  "memory": "250Mi",
2890
2899
  },
2891
2900
  ),
2892
- )
2901
+ ).to_dict()
2893
2902
  )
2894
2903
  )
2895
2904
  .service_account_name(ARGO_EVENTS_SERVICE_ACCOUNT)
@@ -10,7 +10,7 @@ from metaflow.metaflow_config import KUBERNETES_NAMESPACE
10
10
  from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
11
11
  from metaflow.runner.deployer import Deployer, DeployedFlow, TriggeredRun
12
12
 
13
- from metaflow.runner.utils import get_lower_level_group, handle_timeout
13
+ from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo
14
14
 
15
15
 
16
16
  def generate_fake_flow_file_contents(
@@ -341,18 +341,14 @@ class ArgoWorkflowsDeployedFlow(DeployedFlow):
341
341
  Exception
342
342
  If there is an error during the trigger process.
343
343
  """
344
- with tempfile.TemporaryDirectory() as temp_dir:
345
- tfp_runner_attribute = tempfile.NamedTemporaryFile(
346
- dir=temp_dir, delete=False
347
- )
348
-
344
+ with temporary_fifo() as (attribute_file_path, attribute_file_fd):
349
345
  # every subclass needs to have `self.deployer_kwargs`
350
346
  command = get_lower_level_group(
351
347
  self.deployer.api,
352
348
  self.deployer.top_level_kwargs,
353
349
  self.deployer.TYPE,
354
350
  self.deployer.deployer_kwargs,
355
- ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
351
+ ).trigger(deployer_attribute_file=attribute_file_path, **kwargs)
356
352
 
357
353
  pid = self.deployer.spm.run_command(
358
354
  [sys.executable, *command],
@@ -363,7 +359,7 @@ class ArgoWorkflowsDeployedFlow(DeployedFlow):
363
359
 
364
360
  command_obj = self.deployer.spm.get(pid)
365
361
  content = handle_timeout(
366
- tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
362
+ attribute_file_fd, command_obj, self.deployer.file_read_timeout
367
363
  )
368
364
 
369
365
  if command_obj.process.returncode == 0:
@@ -6,7 +6,7 @@ from typing import ClassVar, Optional, List
6
6
  from metaflow.plugins.aws.step_functions.step_functions import StepFunctions
7
7
  from metaflow.runner.deployer import DeployedFlow, TriggeredRun
8
8
 
9
- from metaflow.runner.utils import get_lower_level_group, handle_timeout
9
+ from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo
10
10
 
11
11
 
12
12
  class StepFunctionsTriggeredRun(TriggeredRun):
@@ -196,18 +196,14 @@ class StepFunctionsDeployedFlow(DeployedFlow):
196
196
  Exception
197
197
  If there is an error during the trigger process.
198
198
  """
199
- with tempfile.TemporaryDirectory() as temp_dir:
200
- tfp_runner_attribute = tempfile.NamedTemporaryFile(
201
- dir=temp_dir, delete=False
202
- )
203
-
199
+ with temporary_fifo() as (attribute_file_path, attribute_file_fd):
204
200
  # every subclass needs to have `self.deployer_kwargs`
205
201
  command = get_lower_level_group(
206
202
  self.deployer.api,
207
203
  self.deployer.top_level_kwargs,
208
204
  self.deployer.TYPE,
209
205
  self.deployer.deployer_kwargs,
210
- ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
206
+ ).trigger(deployer_attribute_file=attribute_file_path, **kwargs)
211
207
 
212
208
  pid = self.deployer.spm.run_command(
213
209
  [sys.executable, *command],
@@ -218,7 +214,7 @@ class StepFunctionsDeployedFlow(DeployedFlow):
218
214
 
219
215
  command_obj = self.deployer.spm.get(pid)
220
216
  content = handle_timeout(
221
- tfp_runner_attribute, command_obj, self.deployer.file_read_timeout
217
+ attribute_file_fd, command_obj, self.deployer.file_read_timeout
222
218
  )
223
219
 
224
220
  if command_obj.process.returncode == 0:
@@ -600,7 +600,9 @@ class S3(object):
600
600
  # returned are Unicode.
601
601
  key = getattr(key_value, "key", key_value)
602
602
  if self._s3root is None:
603
- parsed = urlparse(to_unicode(key))
603
+ # NOTE: S3 allows fragments as part of object names, e.g. /dataset #1/data.txt
604
+ # Without allow_fragments=False the parsed.path for an object name with fragments is incomplete.
605
+ parsed = urlparse(to_unicode(key), allow_fragments=False)
604
606
  if parsed.scheme == "s3" and parsed.path:
605
607
  return key
606
608
  else:
@@ -765,7 +767,9 @@ class S3(object):
765
767
  """
766
768
 
767
769
  url = self._url(key)
768
- src = urlparse(url)
770
+ # NOTE: S3 allows fragments as part of object names, e.g. /dataset #1/data.txt
771
+ # Without allow_fragments=False the parsed src.path for an object name with fragments is incomplete.
772
+ src = urlparse(url, allow_fragments=False)
769
773
 
770
774
  def _info(s3, tmp):
771
775
  resp = s3.head_object(Bucket=src.netloc, Key=src.path.lstrip('/"'))
@@ -891,7 +895,9 @@ class S3(object):
891
895
  DOWNLOAD_MAX_CHUNK = 2 * 1024 * 1024 * 1024 - 1
892
896
 
893
897
  url, r = self._url_and_range(key)
894
- src = urlparse(url)
898
+ # NOTE: S3 allows fragments as part of object names, e.g. /dataset #1/data.txt
899
+ # Without allow_fragments=False the parsed src.path for an object name with fragments is incomplete.
900
+ src = urlparse(url, allow_fragments=False)
895
901
 
896
902
  def _download(s3, tmp):
897
903
  if r:
@@ -1173,7 +1179,9 @@ class S3(object):
1173
1179
  blob.close = lambda: None
1174
1180
 
1175
1181
  url = self._url(key)
1176
- src = urlparse(url)
1182
+ # NOTE: S3 allows fragments as part of object names, e.g. /dataset #1/data.txt
1183
+ # Without allow_fragments=False the parsed src.path for an object name with fragments is incomplete.
1184
+ src = urlparse(url, allow_fragments=False)
1177
1185
  extra_args = None
1178
1186
  if content_type or metadata or self._encryption:
1179
1187
  extra_args = {}
@@ -170,7 +170,7 @@ class TriggerDecorator(FlowDecorator):
170
170
  # process every event in events
171
171
  for event in self.attributes["events"]:
172
172
  processed_event = self.process_event_name(event)
173
- self.triggers.append("processed event", processed_event)
173
+ self.triggers.append(processed_event)
174
174
  elif callable(self.attributes["events"]) and not isinstance(
175
175
  self.attributes["events"], DeployTimeField
176
176
  ):
@@ -23,3 +23,32 @@ def parse_cli_options(flow_name, run_id, user, my_runs, echo):
23
23
  raise CommandException("A previous run id was not found. Specify --run-id.")
24
24
 
25
25
  return flow_name, run_id, user
26
+
27
+
28
+ def qos_requests_and_limits(qos: str, cpu: int, memory: int, storage: int):
29
+ "return resource requests and limits for the kubernetes pod based on the given QoS Class"
30
+ # case insensitive matching for QoS class
31
+ qos = qos.lower()
32
+ # Determine the requests and limits to define chosen QoS class
33
+ qos_limits = {}
34
+ qos_requests = {}
35
+ if qos == "guaranteed":
36
+ # Guaranteed - has both cpu/memory limits. requests not required, as these will be inferred.
37
+ qos_limits = {
38
+ "cpu": str(cpu),
39
+ "memory": "%sM" % str(memory),
40
+ "ephemeral-storage": "%sM" % str(storage),
41
+ }
42
+ # NOTE: Even though Kubernetes will produce matching requests for the specified limits, this happens late in the lifecycle.
43
+ # We specify them explicitly here to make some K8S tooling happy, in case they rely on .resources.requests being present at time of submitting the job.
44
+ qos_requests = qos_limits
45
+ else:
46
+ # Burstable - not Guaranteed, and has a memory/cpu limit or request
47
+ qos_requests = {
48
+ "cpu": str(cpu),
49
+ "memory": "%sM" % str(memory),
50
+ "ephemeral-storage": "%sM" % str(storage),
51
+ }
52
+ # TODO: Add support for BestEffort once there is a use case for it.
53
+ # BestEffort - no limit or requests for cpu/memory
54
+ return qos_requests, qos_limits
@@ -196,6 +196,7 @@ class Kubernetes(object):
196
196
  shared_memory=None,
197
197
  port=None,
198
198
  num_parallel=None,
199
+ qos=None,
199
200
  ):
200
201
  name = "js-%s" % str(uuid4())[:6]
201
202
  jobset = (
@@ -228,6 +229,7 @@ class Kubernetes(object):
228
229
  shared_memory=shared_memory,
229
230
  port=port,
230
231
  num_parallel=num_parallel,
232
+ qos=qos,
231
233
  )
232
234
  .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
233
235
  .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -504,6 +506,7 @@ class Kubernetes(object):
504
506
  shared_memory=None,
505
507
  port=None,
506
508
  name_pattern=None,
509
+ qos=None,
507
510
  ):
508
511
  if env is None:
509
512
  env = {}
@@ -544,6 +547,7 @@ class Kubernetes(object):
544
547
  persistent_volume_claims=persistent_volume_claims,
545
548
  shared_memory=shared_memory,
546
549
  port=port,
550
+ qos=qos,
547
551
  )
548
552
  .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
549
553
  .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -126,6 +126,12 @@ def kubernetes():
126
126
  type=int,
127
127
  help="Number of parallel nodes to run as a multi-node job.",
128
128
  )
129
+ @click.option(
130
+ "--qos",
131
+ default=None,
132
+ type=str,
133
+ help="Quality of Service class for the Kubernetes pod",
134
+ )
129
135
  @click.pass_context
130
136
  def step(
131
137
  ctx,
@@ -154,6 +160,7 @@ def step(
154
160
  shared_memory=None,
155
161
  port=None,
156
162
  num_parallel=None,
163
+ qos=None,
157
164
  **kwargs
158
165
  ):
159
166
  def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -294,6 +301,7 @@ def step(
294
301
  shared_memory=shared_memory,
295
302
  port=port,
296
303
  num_parallel=num_parallel,
304
+ qos=qos,
297
305
  )
298
306
  except Exception as e:
299
307
  traceback.print_exc(chain=False)
@@ -26,6 +26,7 @@ from metaflow.metaflow_config import (
26
26
  KUBERNETES_SERVICE_ACCOUNT,
27
27
  KUBERNETES_SHARED_MEMORY,
28
28
  KUBERNETES_TOLERATIONS,
29
+ KUBERNETES_QOS,
29
30
  )
30
31
  from metaflow.plugins.resources_decorator import ResourcesDecorator
31
32
  from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
@@ -43,6 +44,8 @@ except NameError:
43
44
  unicode = str
44
45
  basestring = str
45
46
 
47
+ SUPPORTED_KUBERNETES_QOS_CLASSES = ["Guaranteed", "Burstable"]
48
+
46
49
 
47
50
  class KubernetesDecorator(StepDecorator):
48
51
  """
@@ -111,6 +114,8 @@ class KubernetesDecorator(StepDecorator):
111
114
  hostname_resolution_timeout: int, default 10 * 60
112
115
  Timeout in seconds for the workers tasks in the gang scheduled cluster to resolve the hostname of control task.
113
116
  Only applicable when @parallel is used.
117
+ qos: str, default: Burstable
118
+ Quality of Service class to assign to the pod. Supported values are: Guaranteed, Burstable, BestEffort
114
119
  """
115
120
 
116
121
  name = "kubernetes"
@@ -138,6 +143,7 @@ class KubernetesDecorator(StepDecorator):
138
143
  "compute_pool": None,
139
144
  "executable": None,
140
145
  "hostname_resolution_timeout": 10 * 60,
146
+ "qos": KUBERNETES_QOS,
141
147
  }
142
148
  package_url = None
143
149
  package_sha = None
@@ -261,6 +267,17 @@ class KubernetesDecorator(StepDecorator):
261
267
  self.step = step
262
268
  self.flow_datastore = flow_datastore
263
269
 
270
+ if (
271
+ self.attributes["qos"] is not None
272
+ # case insensitive matching.
273
+ and self.attributes["qos"].lower()
274
+ not in [c.lower() for c in SUPPORTED_KUBERNETES_QOS_CLASSES]
275
+ ):
276
+ raise MetaflowException(
277
+ "*%s* is not a valid Kubernetes QoS class. Choose one of the following: %s"
278
+ % (self.attributes["qos"], ", ".join(SUPPORTED_KUBERNETES_QOS_CLASSES))
279
+ )
280
+
264
281
  if any([deco.name == "batch" for deco in decos]):
265
282
  raise MetaflowException(
266
283
  "Step *{step}* is marked for execution both on AWS Batch and "
@@ -16,6 +16,8 @@ from .kubernetes_jobsets import (
16
16
  KubernetesJobSet,
17
17
  ) # We need this import for Kubernetes Client.
18
18
 
19
+ from .kube_utils import qos_requests_and_limits
20
+
19
21
 
20
22
  class KubernetesJobException(MetaflowException):
21
23
  headline = "Kubernetes job error"
@@ -75,6 +77,12 @@ class KubernetesJob(object):
75
77
  if self._kwargs["shared_memory"]
76
78
  else None
77
79
  )
80
+ qos_requests, qos_limits = qos_requests_and_limits(
81
+ self._kwargs["qos"],
82
+ self._kwargs["cpu"],
83
+ self._kwargs["memory"],
84
+ self._kwargs["disk"],
85
+ )
78
86
  initial_configs = init_config()
79
87
  for entry in ["OBP_PERIMETER", "OBP_INTEGRATIONS_SECRETS_METADATA_URL"]:
80
88
  if entry not in initial_configs:
@@ -176,20 +184,18 @@ class KubernetesJob(object):
176
184
  image_pull_policy=self._kwargs["image_pull_policy"],
177
185
  name=self._kwargs["step_name"].replace("_", "-"),
178
186
  resources=client.V1ResourceRequirements(
179
- requests={
180
- "cpu": str(self._kwargs["cpu"]),
181
- "memory": "%sM" % str(self._kwargs["memory"]),
182
- "ephemeral-storage": "%sM"
183
- % str(self._kwargs["disk"]),
184
- },
187
+ requests=qos_requests,
185
188
  limits={
186
- "%s.com/gpu".lower()
187
- % self._kwargs["gpu_vendor"]: str(
188
- self._kwargs["gpu"]
189
- )
190
- for k in [0]
191
- # Don't set GPU limits if gpu isn't specified.
192
- if self._kwargs["gpu"] is not None
189
+ **qos_limits,
190
+ **{
191
+ "%s.com/gpu".lower()
192
+ % self._kwargs["gpu_vendor"]: str(
193
+ self._kwargs["gpu"]
194
+ )
195
+ for k in [0]
196
+ # Don't set GPU limits if gpu isn't specified.
197
+ if self._kwargs["gpu"] is not None
198
+ },
193
199
  },
194
200
  ),
195
201
  volume_mounts=(
@@ -9,6 +9,8 @@ from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_
9
9
  from metaflow.tracing import inject_tracing_vars
10
10
  from metaflow.metaflow_config import KUBERNETES_SECRETS
11
11
 
12
+ from .kube_utils import qos_requests_and_limits
13
+
12
14
 
13
15
  class KubernetesJobsetException(MetaflowException):
14
16
  headline = "Kubernetes jobset error"
@@ -554,7 +556,12 @@ class JobSetSpec(object):
554
556
  if self._kwargs["shared_memory"]
555
557
  else None
556
558
  )
557
-
559
+ qos_requests, qos_limits = qos_requests_and_limits(
560
+ self._kwargs["qos"],
561
+ self._kwargs["cpu"],
562
+ self._kwargs["memory"],
563
+ self._kwargs["disk"],
564
+ )
558
565
  return dict(
559
566
  name=self.name,
560
567
  template=client.api_client.ApiClient().sanitize_for_serialization(
@@ -653,21 +660,18 @@ class JobSetSpec(object):
653
660
  "_", "-"
654
661
  ),
655
662
  resources=client.V1ResourceRequirements(
656
- requests={
657
- "cpu": str(self._kwargs["cpu"]),
658
- "memory": "%sM"
659
- % str(self._kwargs["memory"]),
660
- "ephemeral-storage": "%sM"
661
- % str(self._kwargs["disk"]),
662
- },
663
+ requests=qos_requests,
663
664
  limits={
664
- "%s.com/gpu".lower()
665
- % self._kwargs["gpu_vendor"]: str(
666
- self._kwargs["gpu"]
667
- )
668
- for k in [0]
669
- # Don't set GPU limits if gpu isn't specified.
670
- if self._kwargs["gpu"] is not None
665
+ **qos_limits,
666
+ **{
667
+ "%s.com/gpu".lower()
668
+ % self._kwargs["gpu_vendor"]: str(
669
+ self._kwargs["gpu"]
670
+ )
671
+ for k in [0]
672
+ # Don't set GPU limits if gpu isn't specified.
673
+ if self._kwargs["gpu"] is not None
674
+ },
671
675
  },
672
676
  ),
673
677
  volume_mounts=(