metaflow 2.12.33__py2.py3-none-any.whl → 2.12.35__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,11 @@
1
1
  import re
2
+ import json
2
3
 
3
4
  from metaflow import current
4
5
  from metaflow.decorators import FlowDecorator
5
6
  from metaflow.exception import MetaflowException
6
7
  from metaflow.util import is_stringish
8
+ from metaflow.parameters import DeployTimeField, deploy_time_eval
7
9
 
8
10
  # TODO: Support dynamic parameter mapping through a context object that exposes
9
11
  # flow name and user name similar to parameter context
@@ -68,6 +70,75 @@ class TriggerDecorator(FlowDecorator):
68
70
  "options": {},
69
71
  }
70
72
 
73
def process_event_name(self, event):
    """
    Normalize one *event* specification given to the *@trigger* decorator.

    Parameters
    ----------
    event : str, dict or callable
        - A string is taken as the event name.
        - A dictionary must contain a *name* key and may contain a
          *parameters* key. The dictionary is updated in place: a callable
          name is wrapped in a `DeployTimeField` and *parameters* is replaced
          by the result of `process_parameters`.
        - Any other callable that is not already a `DeployTimeField` is
          wrapped in a `DeployTimeField`.

    Returns
    -------
    dict or DeployTimeField
        ``{"name": ...}`` for a string, the updated dictionary for a
        dictionary, or a `DeployTimeField` wrapping the callable.

    Raises
    ------
    MetaflowException
        If a dictionary has no *name* key, or if *event* has none of the
        supported types.
    """
    # Plain event name.
    if is_stringish(event):
        return {"name": str(event)}

    # Event described by a dictionary with a mandatory name.
    if isinstance(event, dict):
        if "name" not in event:
            raise MetaflowException(
                "The *event* attribute for *@trigger* is missing the *name* key."
            )
        name = event["name"]
        if callable(name) and not isinstance(name, DeployTimeField):
            event["name"] = DeployTimeField("event_name", str, None, name, False)
        event["parameters"] = self.process_parameters(event.get("parameters", {}))
        return event

    # The whole event is produced by a user function.
    if callable(event) and not isinstance(event, DeployTimeField):
        return DeployTimeField("event", [str, dict], None, event, False)

    raise MetaflowException(
        "Incorrect format for *event* attribute in *@trigger* decorator. "
        "Supported formats are string and dictionary - \n"
        "@trigger(event='foo') or @trigger(event={'name': 'foo', "
        "'parameters': {'alpha': 'beta'}})"
    )
98
+
99
def process_parameters(self, parameters):
    """
    Normalize the *parameters* attribute of an event into a mapping.

    Parameters
    ----------
    parameters : list, tuple, dict or callable
        - list/tuple: each item is either a string (mapped to itself), a
          callable (wrapped in a `DeployTimeField` and mapped to itself), or
          a list/tuple of size 2 giving ``(key, value)``.
        - dict: keys and values are copied; callables among them are wrapped
          in a `DeployTimeField`.
        - callable: the whole value is wrapped in a `DeployTimeField`.

    Returns
    -------
    dict or DeployTimeField
        The normalized mapping, or a `DeployTimeField` when *parameters*
        itself is a callable. Any other input yields an empty dict.

    Raises
    ------
    MetaflowException
        If an item of a list/tuple is neither a string, a callable, nor a
        list/tuple of size 2.
    """

    def _wrap(value, field_name):
        # Wrap user callables exactly once; everything else passes through.
        if callable(value) and not isinstance(value, DeployTimeField):
            return DeployTimeField(field_name, str, None, value, False)
        return value

    new_param_values = {}
    if isinstance(parameters, (list, tuple)):
        for mapping in parameters:
            if is_stringish(mapping):
                new_param_values[mapping] = mapping
            elif callable(mapping) and not isinstance(mapping, DeployTimeField):
                # The same object is used as key and value so that later
                # processing can recognise the pair by identity.
                wrapped = _wrap(mapping, "parameter_val")
                new_param_values[wrapped] = wrapped
            elif isinstance(mapping, (list, tuple)) and len(mapping) == 2:
                # Work on local copies: item assignment would fail on tuples
                # and would silently modify a caller-owned list.
                key = _wrap(mapping[0], "parameter_val")
                value = _wrap(mapping[1], "parameter_val")
                new_param_values[key] = value
            else:
                raise MetaflowException(
                    "The *parameters* attribute for event is invalid. "
                    "It should be a list/tuple of strings and lists/tuples of size 2"
                )
    elif callable(parameters) and not isinstance(parameters, DeployTimeField):
        return DeployTimeField(
            "parameters", [list, dict, tuple], None, parameters, False
        )
    elif isinstance(parameters, dict):
        for raw_key, raw_value in parameters.items():
            key = _wrap(raw_key, "flow_parameter")
            value = _wrap(raw_value, "signal_parameter")
            new_param_values[key] = value
    return new_param_values
141
+
71
142
  def flow_init(
72
143
  self,
73
144
  flow_name,
@@ -86,41 +157,9 @@ class TriggerDecorator(FlowDecorator):
86
157
  "attributes in *@trigger* decorator."
87
158
  )
88
159
  elif self.attributes["event"]:
89
- # event attribute supports the following formats -
90
- # 1. event='table.prod_db.members'
91
- # 2. event={'name': 'table.prod_db.members',
92
- # 'parameters': {'alpha': 'member_weight'}}
93
- if is_stringish(self.attributes["event"]):
94
- self.triggers.append({"name": str(self.attributes["event"])})
95
- elif isinstance(self.attributes["event"], dict):
96
- if "name" not in self.attributes["event"]:
97
- raise MetaflowException(
98
- "The *event* attribute for *@trigger* is missing the "
99
- "*name* key."
100
- )
101
- param_value = self.attributes["event"].get("parameters", {})
102
- if isinstance(param_value, (list, tuple)):
103
- new_param_value = {}
104
- for mapping in param_value:
105
- if is_stringish(mapping):
106
- new_param_value[mapping] = mapping
107
- elif isinstance(mapping, (list, tuple)) and len(mapping) == 2:
108
- new_param_value[mapping[0]] = mapping[1]
109
- else:
110
- raise MetaflowException(
111
- "The *parameters* attribute for event '%s' is invalid. "
112
- "It should be a list/tuple of strings and lists/tuples "
113
- "of size 2" % self.attributes["event"]["name"]
114
- )
115
- self.attributes["event"]["parameters"] = new_param_value
116
- self.triggers.append(self.attributes["event"])
117
- else:
118
- raise MetaflowException(
119
- "Incorrect format for *event* attribute in *@trigger* decorator. "
120
- "Supported formats are string and dictionary - \n"
121
- "@trigger(event='foo') or @trigger(event={'name': 'foo', "
122
- "'parameters': {'alpha': 'beta'}})"
123
- )
160
+ event = self.attributes["event"]
161
+ processed_event = self.process_event_name(event)
162
+ self.triggers.append(processed_event)
124
163
  elif self.attributes["events"]:
125
164
  # events attribute supports the following formats -
126
165
  # 1. events=[{'name': 'table.prod_db.members',
@@ -128,43 +167,17 @@ class TriggerDecorator(FlowDecorator):
128
167
  # {'name': 'table.prod_db.metadata',
129
168
  # 'parameters': {'beta': 'grade'}}]
130
169
  if isinstance(self.attributes["events"], list):
170
+ # process every event in events
131
171
  for event in self.attributes["events"]:
132
- if is_stringish(event):
133
- self.triggers.append({"name": str(event)})
134
- elif isinstance(event, dict):
135
- if "name" not in event:
136
- raise MetaflowException(
137
- "One or more events in *events* attribute for "
138
- "*@trigger* are missing the *name* key."
139
- )
140
- param_value = event.get("parameters", {})
141
- if isinstance(param_value, (list, tuple)):
142
- new_param_value = {}
143
- for mapping in param_value:
144
- if is_stringish(mapping):
145
- new_param_value[mapping] = mapping
146
- elif (
147
- isinstance(mapping, (list, tuple))
148
- and len(mapping) == 2
149
- ):
150
- new_param_value[mapping[0]] = mapping[1]
151
- else:
152
- raise MetaflowException(
153
- "The *parameters* attribute for event '%s' is "
154
- "invalid. It should be a list/tuple of strings "
155
- "and lists/tuples of size 2" % event["name"]
156
- )
157
- event["parameters"] = new_param_value
158
- self.triggers.append(event)
159
- else:
160
- raise MetaflowException(
161
- "One or more events in *events* attribute in *@trigger* "
162
- "decorator have an incorrect format. Supported format "
163
- "is dictionary - \n"
164
- "@trigger(events=[{'name': 'foo', 'parameters': {'alpha': "
165
- "'beta'}}, {'name': 'bar', 'parameters': "
166
- "{'gamma': 'kappa'}}])"
167
- )
172
+ processed_event = self.process_event_name(event)
173
+ self.triggers.append(processed_event)
174
+ elif callable(self.attributes["events"]) and not isinstance(
175
+ self.attributes["events"], DeployTimeField
176
+ ):
177
+ trig = DeployTimeField(
178
+ "events", list, None, self.attributes["events"], False
179
+ )
180
+ self.triggers.append(trig)
168
181
  else:
169
182
  raise MetaflowException(
170
183
  "Incorrect format for *events* attribute in *@trigger* decorator. "
@@ -178,7 +191,12 @@ class TriggerDecorator(FlowDecorator):
178
191
  raise MetaflowException("No event(s) specified in *@trigger* decorator.")
179
192
 
180
193
  # same event shouldn't occur more than once
181
- names = [x["name"] for x in self.triggers]
194
+ names = [
195
+ x["name"]
196
+ for x in self.triggers
197
+ if not isinstance(x, DeployTimeField)
198
+ and not isinstance(x["name"], DeployTimeField)
199
+ ]
182
200
  if len(names) != len(set(names)):
183
201
  raise MetaflowException(
184
202
  "Duplicate event names defined in *@trigger* decorator."
@@ -188,6 +206,104 @@ class TriggerDecorator(FlowDecorator):
188
206
 
189
207
  # TODO: Handle scenario for local testing using --trigger.
190
208
 
209
def format_deploytime_value(self):
    """
    Evaluate every deploy-time piece of ``self.triggers``.

    The work happens in three layers:

    1. A trigger that is itself a `DeployTimeField` is evaluated. A list
       result is expanded into one trigger per item, a string becomes
       ``{"name": ...}`` and a dictionary is used as is.
    2. For each trigger, a deploy-time *parameters* value is evaluated and a
       list/tuple of parameters is converted to a dictionary; a deploy-time
       *name* is evaluated.
    3. Deploy-time keys and values inside each *parameters* dictionary are
       evaluated.

    ``self.triggers`` is replaced by the expanded list and the trigger
    dictionaries are updated in place.

    Raises
    ------
    MetaflowException
        If an item of a parameters list/tuple is neither a string, a
        callable, nor a list/tuple of size 2.
    """

    def _as_field(value):
        # Wrap user callables so the third layer evaluates them.
        if callable(value) and not isinstance(value, DeployTimeField):
            return DeployTimeField("parameter_val", str, None, value, False)
        return value

    # Layer 1: expand triggers that are produced by a function.
    expanded = []
    for trigger in self.triggers:
        if not isinstance(trigger, DeployTimeField):
            expanded.append(trigger)
            continue
        evaluated_trigger = deploy_time_eval(trigger)
        if isinstance(evaluated_trigger, list):
            for trig in evaluated_trigger:
                # Non-string items (dictionaries) are kept as returned.
                expanded.append({"name": trig} if is_stringish(trig) else trig)
        elif isinstance(evaluated_trigger, dict):
            expanded.append(evaluated_trigger)
        elif isinstance(evaluated_trigger, str):
            expanded.append({"name": evaluated_trigger})
        else:
            # Unsupported result type: keep the field untouched.
            expanded.append(trigger)
    self.triggers = expanded

    # Layer 2: resolve the parameters container and the event name.
    for trigger in self.triggers:
        trigger_params = trigger.get("parameters", {})
        # The parameters may come from a function returning a list or dict.
        if isinstance(trigger_params, DeployTimeField):
            trigger_params = deploy_time_eval(trigger_params)
        if isinstance(trigger_params, (list, tuple)):
            new_trigger_params = {}
            for mapping in trigger_params:
                if is_stringish(mapping) or isinstance(mapping, DeployTimeField):
                    new_trigger_params[mapping] = mapping
                elif callable(mapping):
                    # Same object as key and value; layer 3 relies on the
                    # identity to evaluate it only once.
                    wrapped = _as_field(mapping)
                    new_trigger_params[wrapped] = wrapped
                elif isinstance(mapping, (list, tuple)) and len(mapping) == 2:
                    # Local copies: tuples do not support item assignment.
                    key = _as_field(mapping[0])
                    value = _as_field(mapping[1])
                    new_trigger_params[key] = value
                else:
                    raise MetaflowException(
                        "The *parameters* attribute for event '%s' is invalid. "
                        "It should be a list/tuple of strings and lists/tuples "
                        "of size 2" % trigger.get("name")
                    )
            trigger_params = new_trigger_params
        trigger["parameters"] = trigger_params

        # The name alone may be a function (expected to return a string).
        trigger_name = trigger.get("name")
        if isinstance(trigger_name, DeployTimeField):
            trigger["name"] = deploy_time_eval(trigger_name)

    # Layer 3: resolve individual keys/values, e.g.
    # {name:, parameters:[func, ..., ...]}
    # {name:, parameters:{func : func2}}
    for trigger in self.triggers:
        new_trigger_params = {}
        for key, value in trigger.get("parameters", {}).items():
            if isinstance(value, DeployTimeField) and key is value:
                evaluated_param = deploy_time_eval(value)
                new_trigger_params[evaluated_param] = evaluated_param
                continue
            # Key and value are independent: evaluate whichever is deferred.
            if isinstance(key, DeployTimeField):
                key = deploy_time_eval(key)
            if isinstance(value, DeployTimeField):
                value = deploy_time_eval(value)
            new_trigger_params[key] = value
        trigger["parameters"] = new_trigger_params
306
+
191
307
 
192
308
  class TriggerOnFinishDecorator(FlowDecorator):
193
309
  """
@@ -312,6 +428,13 @@ class TriggerOnFinishDecorator(FlowDecorator):
312
428
  "The *project_branch* attribute of the *flow* is not a string"
313
429
  )
314
430
  self.triggers.append(result)
431
+ elif callable(self.attributes["flow"]) and not isinstance(
432
+ self.attributes["flow"], DeployTimeField
433
+ ):
434
+ trig = DeployTimeField(
435
+ "fq_name", [str, dict], None, self.attributes["flow"], False
436
+ )
437
+ self.triggers.append(trig)
315
438
  else:
316
439
  raise MetaflowException(
317
440
  "Incorrect type for *flow* attribute in *@trigger_on_finish* "
@@ -369,6 +492,13 @@ class TriggerOnFinishDecorator(FlowDecorator):
369
492
  "Supported type is string or Dict[str, str]- \n"
370
493
  "@trigger_on_finish(flows=['FooFlow', 'BarFlow']"
371
494
  )
495
+ elif callable(self.attributes["flows"]) and not isinstance(
496
+ self.attributes["flows"], DeployTimeField
497
+ ):
498
+ trig = DeployTimeField(
499
+ "flows", list, None, self.attributes["flows"], False
500
+ )
501
+ self.triggers.append(trig)
372
502
  else:
373
503
  raise MetaflowException(
374
504
  "Incorrect type for *flows* attribute in *@trigger_on_finish* "
@@ -383,6 +513,8 @@ class TriggerOnFinishDecorator(FlowDecorator):
383
513
 
384
514
  # Make triggers @project aware
385
515
  for trigger in self.triggers:
516
+ if isinstance(trigger, DeployTimeField):
517
+ continue
386
518
  if trigger["fq_name"].count(".") == 0:
387
519
  # fully qualified name is just the flow name
388
520
  trigger["flow"] = trigger["fq_name"]
@@ -427,5 +559,54 @@ class TriggerOnFinishDecorator(FlowDecorator):
427
559
  run_objs.append(run_obj)
428
560
  current._update_env({"trigger": Trigger.from_runs(run_objs)})
429
561
 
562
+ def _parse_fq_name(self, trigger):
563
+ if isinstance(trigger, DeployTimeField):
564
+ trigger["fq_name"] = deploy_time_eval(trigger["fq_name"])
565
+ if trigger["fq_name"].count(".") == 0:
566
+ # fully qualified name is just the flow name
567
+ trigger["flow"] = trigger["fq_name"]
568
+ elif trigger["fq_name"].count(".") >= 2:
569
+ # fully qualified name is of the format - project.branch.flow_name
570
+ trigger["project"], tail = trigger["fq_name"].split(".", maxsplit=1)
571
+ trigger["branch"], trigger["flow"] = tail.rsplit(".", maxsplit=1)
572
+ else:
573
+ raise MetaflowException(
574
+ "Incorrect format for *flow* in *@trigger_on_finish* "
575
+ "decorator. Specify either just the *flow_name* or a fully "
576
+ "qualified name like *project_name.branch_name.flow_name*."
577
+ )
578
+ if not re.match(r"^[A-Za-z0-9_]+$", trigger["flow"]):
579
+ raise MetaflowException(
580
+ "Invalid flow name *%s* in *@trigger_on_finish* "
581
+ "decorator. Only alphanumeric characters and "
582
+ "underscores(_) are allowed." % trigger["flow"]
583
+ )
584
+ return trigger
585
+
586
def format_deploytime_value(self):
    """
    Evaluate the deploy-time entries of ``self.triggers``.

    A trigger that is a `DeployTimeField` is evaluated exactly once. The
    result may be a single value or a list of values; each value is
    normalized as follows:

    - a dictionary gets *fq_name*, *project* and *branch* filled from its
      *name*, *project* and *project_branch* keys;
    - a string becomes ``{"fq_name": ...}``;
    - the resulting dictionary is passed through `_parse_fq_name`.

    Triggers that are not deploy-time fields are left untouched.
    ``self.triggers`` is replaced by the resulting list.

    NOTE(review): the exact nesting of the previous implementation could not
    be confirmed; static triggers are deliberately not re-normalized here so
    that their existing *fq_name*/*branch* values cannot be overwritten.

    Raises
    ------
    MetaflowException
        If a deploy-time value is neither a string nor a dictionary, or if a
        dictionary does not provide a *name*.
    """
    resolved = []
    for trigger in self.triggers:
        if not isinstance(trigger, DeployTimeField):
            resolved.append(trigger)
            continue

        # The function may return one flow or a list of flows.
        deploy_value = deploy_time_eval(trigger)
        items = deploy_value if isinstance(deploy_value, list) else [deploy_value]
        for item in items:
            # A list item may itself be deferred.
            if isinstance(item, DeployTimeField):
                item = deploy_time_eval(item)

            if isinstance(item, dict):
                fq_name = item.get("name")
                if isinstance(fq_name, DeployTimeField):
                    fq_name = deploy_time_eval(fq_name)
                if fq_name is None:
                    raise MetaflowException(
                        "The *flow* attribute for *@trigger_on_finish* is "
                        "missing the *name* key."
                    )
                item["fq_name"] = fq_name
                item["project"] = item.get("project")
                item["branch"] = item.get("project_branch")
            elif isinstance(item, str):
                item = {"fq_name": item}
            else:
                raise MetaflowException(
                    "Incorrect type for a deploy-time *flow* value in "
                    "*@trigger_on_finish* decorator. Supported type is "
                    "string or Dict[str, str]."
                )
            # Deploy-time values have not been through flow name parsing yet.
            resolved.append(self._parse_fq_name(item))
    self.triggers = resolved
610
+
430
611
  def get_top_level_options(self):
431
612
  return list(self._option_values.items())
@@ -23,3 +23,32 @@ def parse_cli_options(flow_name, run_id, user, my_runs, echo):
23
23
  raise CommandException("A previous run id was not found. Specify --run-id.")
24
24
 
25
25
  return flow_name, run_id, user
26
+
27
+
28
def qos_requests_and_limits(qos: str, cpu: int, memory: int, storage: int):
    """
    Return resource requests and limits for a Kubernetes pod for a QoS class.

    Parameters
    ----------
    qos : str
        Quality of Service class, matched case-insensitively. ``"guaranteed"``
        selects the Guaranteed layout; any other value, including ``None``
        or an empty string, selects the Burstable layout.
    cpu : int
        CPU amount, rendered with ``str()``.
    memory : int
        Memory amount, rendered with an ``M`` suffix.
    storage : int
        Ephemeral storage amount, rendered with an ``M`` suffix.

    Returns
    -------
    tuple of (dict, dict)
        ``(requests, limits)``. For Guaranteed both dictionaries hold the
        same cpu/memory/ephemeral-storage values (as independent objects);
        for Burstable only the requests are populated and limits is empty.
    """
    resources = {
        "cpu": str(cpu),
        "memory": "%sM" % str(memory),
        "ephemeral-storage": "%sM" % str(storage),
    }
    # Callers default qos to None, so it must not be dereferenced blindly.
    if (qos or "").lower() == "guaranteed":
        # Guaranteed - has both cpu/memory limits.
        # NOTE: Even though Kubernetes will produce matching requests for the
        # specified limits, this happens late in the lifecycle. They are
        # specified explicitly to make some K8S tooling happy, in case it
        # relies on .resources.requests being present at submission time.
        # Separate copies keep later edits to one dict from leaking into the
        # other.
        return dict(resources), dict(resources)
    # Burstable - not Guaranteed, and has a memory/cpu limit or request.
    # TODO: Add support for BestEffort once there is a use case for it.
    # BestEffort - no limit or requests for cpu/memory
    return resources, {}
@@ -196,6 +196,7 @@ class Kubernetes(object):
196
196
  shared_memory=None,
197
197
  port=None,
198
198
  num_parallel=None,
199
+ qos=None,
199
200
  ):
200
201
  name = "js-%s" % str(uuid4())[:6]
201
202
  jobset = (
@@ -228,6 +229,7 @@ class Kubernetes(object):
228
229
  shared_memory=shared_memory,
229
230
  port=port,
230
231
  num_parallel=num_parallel,
232
+ qos=qos,
231
233
  )
232
234
  .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
233
235
  .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -488,6 +490,7 @@ class Kubernetes(object):
488
490
  shared_memory=None,
489
491
  port=None,
490
492
  name_pattern=None,
493
+ qos=None,
491
494
  ):
492
495
  if env is None:
493
496
  env = {}
@@ -528,6 +531,7 @@ class Kubernetes(object):
528
531
  persistent_volume_claims=persistent_volume_claims,
529
532
  shared_memory=shared_memory,
530
533
  port=port,
534
+ qos=qos,
531
535
  )
532
536
  .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
533
537
  .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -126,6 +126,12 @@ def kubernetes():
126
126
  type=int,
127
127
  help="Number of parallel nodes to run as a multi-node job.",
128
128
  )
129
+ @click.option(
130
+ "--qos",
131
+ default=None,
132
+ type=str,
133
+ help="Quality of Service class for the Kubernetes pod",
134
+ )
129
135
  @click.pass_context
130
136
  def step(
131
137
  ctx,
@@ -154,6 +160,7 @@ def step(
154
160
  shared_memory=None,
155
161
  port=None,
156
162
  num_parallel=None,
163
+ qos=None,
157
164
  **kwargs
158
165
  ):
159
166
  def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -294,6 +301,7 @@ def step(
294
301
  shared_memory=shared_memory,
295
302
  port=port,
296
303
  num_parallel=num_parallel,
304
+ qos=qos,
297
305
  )
298
306
  except Exception as e:
299
307
  traceback.print_exc(chain=False)
@@ -26,6 +26,7 @@ from metaflow.metaflow_config import (
26
26
  KUBERNETES_SERVICE_ACCOUNT,
27
27
  KUBERNETES_SHARED_MEMORY,
28
28
  KUBERNETES_TOLERATIONS,
29
+ KUBERNETES_QOS,
29
30
  )
30
31
  from metaflow.plugins.resources_decorator import ResourcesDecorator
31
32
  from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
@@ -41,6 +42,8 @@ except NameError:
41
42
  unicode = str
42
43
  basestring = str
43
44
 
45
+ SUPPORTED_KUBERNETES_QOS_CLASSES = ["Guaranteed", "Burstable"]
46
+
44
47
 
45
48
  class KubernetesDecorator(StepDecorator):
46
49
  """
@@ -109,6 +112,8 @@ class KubernetesDecorator(StepDecorator):
109
112
  hostname_resolution_timeout: int, default 10 * 60
110
113
  Timeout in seconds for the workers tasks in the gang scheduled cluster to resolve the hostname of control task.
111
114
  Only applicable when @parallel is used.
115
+ qos: str, default: Burstable
116
+ Quality of Service class to assign to the pod. Supported values are: Guaranteed, Burstable
112
117
  """
113
118
 
114
119
  name = "kubernetes"
@@ -136,6 +141,7 @@ class KubernetesDecorator(StepDecorator):
136
141
  "compute_pool": None,
137
142
  "executable": None,
138
143
  "hostname_resolution_timeout": 10 * 60,
144
+ "qos": KUBERNETES_QOS,
139
145
  }
140
146
  package_url = None
141
147
  package_sha = None
@@ -259,6 +265,17 @@ class KubernetesDecorator(StepDecorator):
259
265
  self.step = step
260
266
  self.flow_datastore = flow_datastore
261
267
 
268
+ if (
269
+ self.attributes["qos"] is not None
270
+ # case insensitive matching.
271
+ and self.attributes["qos"].lower()
272
+ not in [c.lower() for c in SUPPORTED_KUBERNETES_QOS_CLASSES]
273
+ ):
274
+ raise MetaflowException(
275
+ "*%s* is not a valid Kubernetes QoS class. Choose one of the following: %s"
276
+ % (self.attributes["qos"], ", ".join(SUPPORTED_KUBERNETES_QOS_CLASSES))
277
+ )
278
+
262
279
  if any([deco.name == "batch" for deco in decos]):
263
280
  raise MetaflowException(
264
281
  "Step *{step}* is marked for execution both on AWS Batch and "
@@ -15,6 +15,8 @@ from .kubernetes_jobsets import (
15
15
  KubernetesJobSet,
16
16
  ) # We need this import for Kubernetes Client.
17
17
 
18
+ from .kube_utils import qos_requests_and_limits
19
+
18
20
 
19
21
  class KubernetesJobException(MetaflowException):
20
22
  headline = "Kubernetes job error"
@@ -74,6 +76,13 @@ class KubernetesJob(object):
74
76
  if self._kwargs["shared_memory"]
75
77
  else None
76
78
  )
79
+ qos_requests, qos_limits = qos_requests_and_limits(
80
+ self._kwargs["qos"],
81
+ self._kwargs["cpu"],
82
+ self._kwargs["memory"],
83
+ self._kwargs["disk"],
84
+ )
85
+
77
86
  return client.V1JobSpec(
78
87
  # Retries are handled by Metaflow when it is responsible for
79
88
  # executing the flow. The responsibility is moved to Kubernetes
@@ -154,20 +163,18 @@ class KubernetesJob(object):
154
163
  image_pull_policy=self._kwargs["image_pull_policy"],
155
164
  name=self._kwargs["step_name"].replace("_", "-"),
156
165
  resources=client.V1ResourceRequirements(
157
- requests={
158
- "cpu": str(self._kwargs["cpu"]),
159
- "memory": "%sM" % str(self._kwargs["memory"]),
160
- "ephemeral-storage": "%sM"
161
- % str(self._kwargs["disk"]),
162
- },
166
+ requests=qos_requests,
163
167
  limits={
164
- "%s.com/gpu".lower()
165
- % self._kwargs["gpu_vendor"]: str(
166
- self._kwargs["gpu"]
167
- )
168
- for k in [0]
169
- # Don't set GPU limits if gpu isn't specified.
170
- if self._kwargs["gpu"] is not None
168
+ **qos_limits,
169
+ **{
170
+ "%s.com/gpu".lower()
171
+ % self._kwargs["gpu_vendor"]: str(
172
+ self._kwargs["gpu"]
173
+ )
174
+ for k in [0]
175
+ # Don't set GPU limits if gpu isn't specified.
176
+ if self._kwargs["gpu"] is not None
177
+ },
171
178
  },
172
179
  ),
173
180
  volume_mounts=(
@@ -9,6 +9,8 @@ from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_
9
9
  from metaflow.tracing import inject_tracing_vars
10
10
  from metaflow.metaflow_config import KUBERNETES_SECRETS
11
11
 
12
+ from .kube_utils import qos_requests_and_limits
13
+
12
14
 
13
15
  class KubernetesJobsetException(MetaflowException):
14
16
  headline = "Kubernetes jobset error"
@@ -554,7 +556,12 @@ class JobSetSpec(object):
554
556
  if self._kwargs["shared_memory"]
555
557
  else None
556
558
  )
557
-
559
+ qos_requests, qos_limits = qos_requests_and_limits(
560
+ self._kwargs["qos"],
561
+ self._kwargs["cpu"],
562
+ self._kwargs["memory"],
563
+ self._kwargs["disk"],
564
+ )
558
565
  return dict(
559
566
  name=self.name,
560
567
  template=client.api_client.ApiClient().sanitize_for_serialization(
@@ -653,21 +660,18 @@ class JobSetSpec(object):
653
660
  "_", "-"
654
661
  ),
655
662
  resources=client.V1ResourceRequirements(
656
- requests={
657
- "cpu": str(self._kwargs["cpu"]),
658
- "memory": "%sM"
659
- % str(self._kwargs["memory"]),
660
- "ephemeral-storage": "%sM"
661
- % str(self._kwargs["disk"]),
662
- },
663
+ requests=qos_requests,
663
664
  limits={
664
- "%s.com/gpu".lower()
665
- % self._kwargs["gpu_vendor"]: str(
666
- self._kwargs["gpu"]
667
- )
668
- for k in [0]
669
- # Don't set GPU limits if gpu isn't specified.
670
- if self._kwargs["gpu"] is not None
665
+ **qos_limits,
666
+ **{
667
+ "%s.com/gpu".lower()
668
+ % self._kwargs["gpu_vendor"]: str(
669
+ self._kwargs["gpu"]
670
+ )
671
+ for k in [0]
672
+ # Don't set GPU limits if gpu isn't specified.
673
+ if self._kwargs["gpu"] is not None
674
+ },
671
675
  },
672
676
  ),
673
677
  volume_mounts=(