metaflow-2.12.8-py2.py3-none-any.whl → metaflow-2.12.9-py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (38)
  1. metaflow/__init__.py +2 -0
  2. metaflow/cli.py +12 -4
  3. metaflow/extension_support/plugins.py +1 -0
  4. metaflow/flowspec.py +8 -1
  5. metaflow/lint.py +13 -0
  6. metaflow/metaflow_current.py +0 -8
  7. metaflow/plugins/__init__.py +12 -0
  8. metaflow/plugins/argo/argo_workflows.py +462 -42
  9. metaflow/plugins/argo/argo_workflows_cli.py +60 -3
  10. metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
  11. metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
  12. metaflow/plugins/argo/jobset_input_paths.py +16 -0
  13. metaflow/plugins/aws/batch/batch_decorator.py +16 -13
  14. metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
  15. metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
  16. metaflow/plugins/cards/card_cli.py +1 -1
  17. metaflow/plugins/kubernetes/kubernetes.py +279 -52
  18. metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
  19. metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
  20. metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
  21. metaflow/plugins/kubernetes/kubernetes_job.py +6 -6
  22. metaflow/plugins/kubernetes/kubernetes_jobsets.py +510 -272
  23. metaflow/plugins/parallel_decorator.py +108 -8
  24. metaflow/plugins/secrets/secrets_decorator.py +12 -3
  25. metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
  26. metaflow/runner/deployer.py +386 -0
  27. metaflow/runner/metaflow_runner.py +1 -20
  28. metaflow/runner/nbdeploy.py +130 -0
  29. metaflow/runner/nbrun.py +4 -28
  30. metaflow/runner/utils.py +49 -0
  31. metaflow/runtime.py +246 -134
  32. metaflow/version.py +1 -1
  33. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/METADATA +2 -2
  34. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/RECORD +38 -32
  35. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/WHEEL +1 -1
  36. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/LICENSE +0 -0
  37. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/entry_points.txt +0 -0
  38. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/top_level.txt +0 -0
metaflow/plugins/aws/step_functions/step_functions_deployer.py (new file)
@@ -0,0 +1,251 @@
+ import sys
+ import json
+ import tempfile
+ from typing import Optional, ClassVar, List
+
+ from metaflow.plugins.aws.step_functions.step_functions import StepFunctions
+ from metaflow.runner.deployer import (
+     DeployerImpl,
+     DeployedFlow,
+     TriggeredRun,
+     get_lower_level_group,
+     handle_timeout,
+ )
+
+
+ def terminate(instance: TriggeredRun, **kwargs):
+     """
+     Terminate the running workflow.
+
+     Parameters
+     ----------
+     **kwargs : Any
+         Additional arguments to pass to the terminate command.
+
+     Returns
+     -------
+     bool
+         True if the command was successful, False otherwise.
+     """
+     _, run_id = instance.pathspec.split("/")
+
+     # every subclass needs to have `self.deployer_kwargs`
+     command = get_lower_level_group(
+         instance.deployer.api,
+         instance.deployer.top_level_kwargs,
+         instance.deployer.TYPE,
+         instance.deployer.deployer_kwargs,
+     ).terminate(run_id=run_id, **kwargs)
+
+     pid = instance.deployer.spm.run_command(
+         [sys.executable, *command],
+         env=instance.deployer.env_vars,
+         cwd=instance.deployer.cwd,
+         show_output=instance.deployer.show_output,
+     )
+
+     command_obj = instance.deployer.spm.get(pid)
+     return command_obj.process.returncode == 0
+
+
+ def production_token(instance: DeployedFlow):
+     """
+     Get the production token for the deployed flow.
+
+     Returns
+     -------
+     str, optional
+         The production token, None if it cannot be retrieved.
+     """
+     try:
+         _, production_token = StepFunctions.get_existing_deployment(
+             instance.deployer.name
+         )
+         return production_token
+     except TypeError:
+         return None
+
+
+ def list_runs(instance: DeployedFlow, states: Optional[List[str]] = None):
+     """
+     List runs of the deployed flow.
+
+     Parameters
+     ----------
+     states : Optional[List[str]], optional
+         A list of states to filter the runs by. Allowed values are:
+         RUNNING, SUCCEEDED, FAILED, TIMED_OUT, ABORTED.
+         If not provided, all states will be considered.
+
+     Returns
+     -------
+     List[TriggeredRun]
+         A list of TriggeredRun objects representing the runs of the deployed flow.
+
+     Raises
+     ------
+     ValueError
+         If any of the provided states are invalid or if there are duplicate states.
+     """
+     VALID_STATES = {"RUNNING", "SUCCEEDED", "FAILED", "TIMED_OUT", "ABORTED"}
+
+     if states is None:
+         states = []
+
+     unique_states = set(states)
+     if not unique_states.issubset(VALID_STATES):
+         invalid_states = unique_states - VALID_STATES
+         raise ValueError(
+             f"Invalid states found: {invalid_states}. Valid states are: {VALID_STATES}"
+         )
+
+     if len(states) != len(unique_states):
+         raise ValueError("Duplicate states are not allowed")
+
+     triggered_runs = []
+     executions = StepFunctions.list(instance.deployer.name, states)
+
+     for e in executions:
+         run_id = "sfn-%s" % e["name"]
+         tr = TriggeredRun(
+             deployer=instance.deployer,
+             content=json.dumps(
+                 {
+                     "metadata": instance.deployer.metadata,
+                     "pathspec": "/".join((instance.deployer.flow_name, run_id)),
+                     "name": run_id,
+                 }
+             ),
+         )
+         tr._enrich_object({"terminate": terminate})
+         triggered_runs.append(tr)
+
+     return triggered_runs
+
+
+ def delete(instance: DeployedFlow, **kwargs):
+     """
+     Delete the deployed flow.
+
+     Parameters
+     ----------
+     **kwargs : Any
+         Additional arguments to pass to the delete command.
+
+     Returns
+     -------
+     bool
+         True if the command was successful, False otherwise.
+     """
+     command = get_lower_level_group(
+         instance.deployer.api,
+         instance.deployer.top_level_kwargs,
+         instance.deployer.TYPE,
+         instance.deployer.deployer_kwargs,
+     ).delete(**kwargs)
+
+     pid = instance.deployer.spm.run_command(
+         [sys.executable, *command],
+         env=instance.deployer.env_vars,
+         cwd=instance.deployer.cwd,
+         show_output=instance.deployer.show_output,
+     )
+
+     command_obj = instance.deployer.spm.get(pid)
+     return command_obj.process.returncode == 0
+
+
+ def trigger(instance: DeployedFlow, **kwargs):
+     """
+     Trigger a new run for the deployed flow.
+
+     Parameters
+     ----------
+     **kwargs : Any
+         Additional arguments to pass to the trigger command, `Parameters` in particular
+
+     Returns
+     -------
+     StepFunctionsTriggeredRun
+         The triggered run instance.
+
+     Raises
+     ------
+     Exception
+         If there is an error during the trigger process.
+     """
+     with tempfile.TemporaryDirectory() as temp_dir:
+         tfp_runner_attribute = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
+
+         # every subclass needs to have `self.deployer_kwargs`
+         command = get_lower_level_group(
+             instance.deployer.api,
+             instance.deployer.top_level_kwargs,
+             instance.deployer.TYPE,
+             instance.deployer.deployer_kwargs,
+         ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs)
+
+         pid = instance.deployer.spm.run_command(
+             [sys.executable, *command],
+             env=instance.deployer.env_vars,
+             cwd=instance.deployer.cwd,
+             show_output=instance.deployer.show_output,
+         )
+
+         command_obj = instance.deployer.spm.get(pid)
+         content = handle_timeout(tfp_runner_attribute, command_obj)
+
+         if command_obj.process.returncode == 0:
+             triggered_run = TriggeredRun(deployer=instance.deployer, content=content)
+             triggered_run._enrich_object({"terminate": terminate})
+             return triggered_run
+
+     raise Exception(
+         "Error triggering %s on %s for %s"
+         % (instance.deployer.name, instance.deployer.TYPE, instance.deployer.flow_file)
+     )
+
+
+ class StepFunctionsDeployer(DeployerImpl):
+     """
+     Deployer implementation for AWS Step Functions.
+
+     Attributes
+     ----------
+     TYPE : ClassVar[Optional[str]]
+         The type of the deployer, which is "step-functions".
+     """
+
+     TYPE: ClassVar[Optional[str]] = "step-functions"
+
+     def __init__(self, deployer_kwargs, **kwargs):
+         """
+         Initialize the StepFunctionsDeployer.
+
+         Parameters
+         ----------
+         deployer_kwargs : dict
+             The deployer-specific keyword arguments.
+         **kwargs : Any
+             Additional arguments to pass to the superclass constructor.
+         """
+         self.deployer_kwargs = deployer_kwargs
+         super().__init__(**kwargs)
+
+     def _enrich_deployed_flow(self, deployed_flow: DeployedFlow):
+         """
+         Enrich the DeployedFlow object with additional properties and methods.
+
+         Parameters
+         ----------
+         deployed_flow : DeployedFlow
+             The deployed flow object to enrich.
+         """
+         deployed_flow._enrich_object(
+             {
+                 "production_token": property(production_token),
+                 "trigger": trigger,
+                 "delete": delete,
+                 "list_runs": list_runs,
+             }
+         )
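The module above attaches these functions to DeployedFlow and TriggeredRun objects at runtime via _enrich_object. A rough usage sketch follows, assuming the Deployer entry point, the step_functions() group, and create() exposed by the new metaflow/runner/deployer.py in this release; those names, the flow file, and the alpha parameter are illustrative assumptions, not confirmed by this diff, while trigger/list_runs/production_token/delete/terminate come from the code shown above.

    from metaflow import Deployer  # assumed entry point from metaflow/runner/deployer.py

    # Deploy the flow as an AWS Step Functions state machine (assumed call chain).
    deployed_flow = Deployer("my_flow.py").step_functions().create()

    # Production token recorded for this deployment (property enriched above).
    print(deployed_flow.production_token)

    # Trigger a new execution; flow Parameters are passed as keyword arguments
    # ("alpha" is a hypothetical Parameter of the example flow).
    triggered_run = deployed_flow.trigger(alpha=0.5)

    # List currently running executions of this state machine.
    for run in deployed_flow.list_runs(states=["RUNNING"]):
        print(run.pathspec)

    # Stop the execution we just started, then remove the deployment.
    triggered_run.terminate()
    deployed_flow.delete()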
metaflow/plugins/cards/card_cli.py
@@ -752,7 +752,7 @@ def create(
              return _card.render(
                  task,
                  stack_trace=stack_trace,
-             ).replace(mf_card.RELOAD_POLICY_TOKEN, token)
+             ).replace(_card.RELOAD_POLICY_TOKEN, token)
 
      if error_stack_trace is not None and mode != "refresh":
          rendered_content = _render_error_card(error_stack_trace)
metaflow/plugins/kubernetes/kubernetes.py
@@ -1,9 +1,9 @@
+ import copy
  import json
  import math
  import os
  import re
  import shlex
- import copy
  import time
  from typing import Dict, List, Optional
  from uuid import uuid4
@@ -14,10 +14,11 @@ from metaflow.metaflow_config import (
      ARGO_EVENTS_EVENT,
      ARGO_EVENTS_EVENT_BUS,
      ARGO_EVENTS_EVENT_SOURCE,
-     ARGO_EVENTS_SERVICE_ACCOUNT,
      ARGO_EVENTS_INTERNAL_WEBHOOK_URL,
-     AWS_SECRETS_MANAGER_DEFAULT_REGION,
+     ARGO_EVENTS_SERVICE_ACCOUNT,
      ARGO_EVENTS_WEBHOOK_AUTH,
+     AWS_SECRETS_MANAGER_DEFAULT_REGION,
+     AZURE_KEY_VAULT_PREFIX,
      AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
      CARD_AZUREROOT,
      CARD_GSROOT,
@@ -31,18 +32,18 @@ from metaflow.metaflow_config import (
      DEFAULT_METADATA,
      DEFAULT_SECRETS_BACKEND_TYPE,
      GCP_SECRET_MANAGER_PREFIX,
-     AZURE_KEY_VAULT_PREFIX,
      KUBERNETES_FETCH_EC2_METADATA,
      KUBERNETES_LABELS,
      KUBERNETES_SANDBOX_INIT_SCRIPT,
+     OTEL_ENDPOINT,
      S3_ENDPOINT_URL,
+     S3_SERVER_SIDE_ENCRYPTION,
      SERVICE_HEADERS,
+     KUBERNETES_SECRETS,
      SERVICE_INTERNAL_URL,
-     S3_SERVER_SIDE_ENCRYPTION,
-     OTEL_ENDPOINT,
  )
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
  from metaflow.metaflow_config_funcs import config_values
-
  from metaflow.mflog import (
      BASH_SAVE_LOGS,
      bash_capture_logs,
@@ -60,6 +61,10 @@ STDERR_FILE = "mflog_stderr"
  STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
  STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
 
+ METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE = (
+     "{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}"
+ )
+
 
  class KubernetesException(MetaflowException):
      headline = "Kubernetes error"
@@ -69,12 +74,6 @@ class KubernetesKilledException(MetaflowException):
      headline = "Kubernetes Batch job killed"
 
 
- def _extract_labels_and_annotations_from_job_spec(job_spec):
-     annotations = job_spec.template.metadata.annotations
-     labels = job_spec.template.metadata.labels
-     return copy.copy(annotations), copy.copy(labels)
-
-
  class Kubernetes(object):
      def __init__(
          self,
@@ -154,57 +153,287 @@ class Kubernetes(object):
              and kwargs["num_parallel"]
              and int(kwargs["num_parallel"]) > 0
          ):
-             job = self.create_job_object(**kwargs)
-             spec = job.create_job_spec()
-             # `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
-             # This will be modified by the KubernetesJobSet object
-             annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
-             self._job = self.create_jobset(
-                 job_spec=spec,
-                 run_id=kwargs["run_id"],
-                 step_name=kwargs["step_name"],
-                 task_id=kwargs["task_id"],
-                 namespace=kwargs["namespace"],
-                 env=kwargs["env"],
-                 num_parallel=kwargs["num_parallel"],
-                 port=kwargs["port"],
-                 annotations=annotations,
-                 labels=labels,
-             ).execute()
+             self._job = self.create_jobset(**kwargs).execute()
          else:
+             kwargs.pop("num_parallel", None)
              kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
              self._job = self.create_job_object(**kwargs).create().execute()
 
      def create_jobset(
          self,
-         job_spec=None,
-         run_id=None,
-         step_name=None,
-         task_id=None,
+         flow_name,
+         run_id,
+         step_name,
+         task_id,
+         attempt,
+         user,
+         code_package_sha,
+         code_package_url,
+         code_package_ds,
+         docker_image,
+         docker_image_pull_policy,
+         step_cli=None,
+         service_account=None,
+         secrets=None,
+         node_selector=None,
          namespace=None,
+         cpu=None,
+         gpu=None,
+         gpu_vendor=None,
+         disk=None,
+         memory=None,
+         use_tmpfs=None,
+         tmpfs_tempdir=None,
+         tmpfs_size=None,
+         tmpfs_path=None,
+         run_time_limit=None,
          env=None,
-         num_parallel=None,
-         port=None,
-         annotations=None,
+         persistent_volume_claims=None,
+         tolerations=None,
          labels=None,
+         shared_memory=None,
+         port=None,
+         num_parallel=None,
      ):
-         if env is None:
-             env = {}
+         name = "js-%s" % str(uuid4())[:6]
+         jobset = (
+             KubernetesClient()
+             .jobset(
+                 name=name,
+                 namespace=namespace,
+                 service_account=service_account,
+                 node_selector=node_selector,
+                 image=docker_image,
+                 image_pull_policy=docker_image_pull_policy,
+                 cpu=cpu,
+                 memory=memory,
+                 disk=disk,
+                 gpu=gpu,
+                 gpu_vendor=gpu_vendor,
+                 timeout_in_seconds=run_time_limit,
+                 # Retries are handled by Metaflow runtime
+                 retries=0,
+                 step_name=step_name,
+                 # We set the jobset name as the subdomain.
+                 # todo: [final-refactor] ask @shri what was the motive when we did initial implementation
+                 subdomain=name,
+                 tolerations=tolerations,
+                 use_tmpfs=use_tmpfs,
+                 tmpfs_tempdir=tmpfs_tempdir,
+                 tmpfs_size=tmpfs_size,
+                 tmpfs_path=tmpfs_path,
+                 persistent_volume_claims=persistent_volume_claims,
+                 shared_memory=shared_memory,
+                 port=port,
+                 num_parallel=num_parallel,
+             )
+             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
+             .environment_variable("METAFLOW_CODE_URL", code_package_url)
+             .environment_variable("METAFLOW_CODE_DS", code_package_ds)
+             .environment_variable("METAFLOW_USER", user)
+             .environment_variable("METAFLOW_SERVICE_URL", SERVICE_INTERNAL_URL)
+             .environment_variable(
+                 "METAFLOW_SERVICE_HEADERS",
+                 json.dumps(SERVICE_HEADERS),
+             )
+             .environment_variable("METAFLOW_DATASTORE_SYSROOT_S3", DATASTORE_SYSROOT_S3)
+             .environment_variable("METAFLOW_DATATOOLS_S3ROOT", DATATOOLS_S3ROOT)
+             .environment_variable("METAFLOW_DEFAULT_DATASTORE", self._datastore.TYPE)
+             .environment_variable("METAFLOW_DEFAULT_METADATA", DEFAULT_METADATA)
+             .environment_variable("METAFLOW_KUBERNETES_WORKLOAD", 1)
+             .environment_variable(
+                 "METAFLOW_KUBERNETES_FETCH_EC2_METADATA", KUBERNETES_FETCH_EC2_METADATA
+             )
+             .environment_variable("METAFLOW_RUNTIME_ENVIRONMENT", "kubernetes")
+             .environment_variable(
+                 "METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE", DEFAULT_SECRETS_BACKEND_TYPE
+             )
+             .environment_variable("METAFLOW_CARD_S3ROOT", CARD_S3ROOT)
+             .environment_variable(
+                 "METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER", DEFAULT_AWS_CLIENT_PROVIDER
+             )
+             .environment_variable(
+                 "METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER", DEFAULT_GCP_CLIENT_PROVIDER
+             )
+             .environment_variable(
+                 "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
+                 AWS_SECRETS_MANAGER_DEFAULT_REGION,
+             )
+             .environment_variable(
+                 "METAFLOW_GCP_SECRET_MANAGER_PREFIX", GCP_SECRET_MANAGER_PREFIX
+             )
+             .environment_variable(
+                 "METAFLOW_AZURE_KEY_VAULT_PREFIX", AZURE_KEY_VAULT_PREFIX
+             )
+             .environment_variable("METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL)
+             .environment_variable(
+                 "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT",
+                 AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
+             )
+             .environment_variable(
+                 "METAFLOW_DATASTORE_SYSROOT_AZURE", DATASTORE_SYSROOT_AZURE
+             )
+             .environment_variable("METAFLOW_CARD_AZUREROOT", CARD_AZUREROOT)
+             .environment_variable("METAFLOW_DATASTORE_SYSROOT_GS", DATASTORE_SYSROOT_GS)
+             .environment_variable("METAFLOW_CARD_GSROOT", CARD_GSROOT)
+             # support Metaflow sandboxes
+             .environment_variable(
+                 "METAFLOW_INIT_SCRIPT", KUBERNETES_SANDBOX_INIT_SCRIPT
+             )
+             .environment_variable("METAFLOW_OTEL_ENDPOINT", OTEL_ENDPOINT)
+             # Skip setting METAFLOW_DATASTORE_SYSROOT_LOCAL because metadata sync
+             # between the local user instance and the remote Kubernetes pod
+             # assumes metadata is stored in DATASTORE_LOCAL_DIR on the Kubernetes
+             # pod; this happens when METAFLOW_DATASTORE_SYSROOT_LOCAL is NOT set (
+             # see get_datastore_root_from_config in datastore/local.py).
+         )
 
-         _prefix = str(uuid4())[:6]
-         js = KubernetesClient().jobset(
-             name="js-%s" % _prefix,
+         _labels = self._get_labels(labels)
+         for k, v in _labels.items():
+             jobset.label(k, v)
+
+         for k in list(
+             [] if not secrets else [secrets] if isinstance(secrets, str) else secrets
+         ) + KUBERNETES_SECRETS.split(","):
+             jobset.secret(k)
+
+         jobset.environment_variables_from_selectors(
+             {
+                 "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+                 "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+                 "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+                 "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+                 "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+             }
+         )
+
+         # Temporary passing of *some* environment variables. Do not rely on this
+         # mechanism as it will be removed in the near future
+         for k, v in config_values():
+             if k.startswith("METAFLOW_CONDA_") or k.startswith("METAFLOW_DEBUG_"):
+                 jobset.environment_variable(k, v)
+
+         if S3_SERVER_SIDE_ENCRYPTION is not None:
+             jobset.environment_variable(
+                 "METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION
+             )
+
+         # Set environment variables to support metaflow.integrations.ArgoEvent
+         jobset.environment_variable(
+             "METAFLOW_ARGO_EVENTS_WEBHOOK_URL", ARGO_EVENTS_INTERNAL_WEBHOOK_URL
+         )
+         jobset.environment_variable("METAFLOW_ARGO_EVENTS_EVENT", ARGO_EVENTS_EVENT)
+         jobset.environment_variable(
+             "METAFLOW_ARGO_EVENTS_EVENT_BUS", ARGO_EVENTS_EVENT_BUS
+         )
+         jobset.environment_variable(
+             "METAFLOW_ARGO_EVENTS_EVENT_SOURCE", ARGO_EVENTS_EVENT_SOURCE
+         )
+         jobset.environment_variable(
+             "METAFLOW_ARGO_EVENTS_SERVICE_ACCOUNT", ARGO_EVENTS_SERVICE_ACCOUNT
+         )
+         jobset.environment_variable(
+             "METAFLOW_ARGO_EVENTS_WEBHOOK_AUTH",
+             ARGO_EVENTS_WEBHOOK_AUTH,
+         )
+
+         ## -----Jobset specific env vars START here-----
+         jobset.environment_variable("MF_MASTER_ADDR", jobset.jobset_control_addr)
+         jobset.environment_variable("MF_MASTER_PORT", str(port))
+         jobset.environment_variable("MF_WORLD_SIZE", str(num_parallel))
+         jobset.environment_variable_from_selector(
+             "JOBSET_RESTART_ATTEMPT",
+             "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
+         )
+         jobset.environment_variable_from_selector(
+             "METAFLOW_KUBERNETES_JOBSET_NAME",
+             "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
+         )
+         jobset.environment_variable_from_selector(
+             "MF_WORKER_REPLICA_INDEX",
+             "metadata.annotations['jobset.sigs.k8s.io/job-index']",
+         )
+         ## -----Jobset specific env vars END here-----
+
+         tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
+         if tmpfs_enabled and tmpfs_tempdir:
+             jobset.environment_variable("METAFLOW_TEMPDIR", tmpfs_path)
+
+         for name, value in env.items():
+             jobset.environment_variable(name, value)
+
+         annotations = {
+             "metaflow/user": user,
+             "metaflow/flow_name": flow_name,
+             "metaflow/control-task-id": task_id,
+         }
+         if current.get("project_name"):
+             annotations.update(
+                 {
+                     "metaflow/project_name": current.project_name,
+                     "metaflow/branch_name": current.branch_name,
+                     "metaflow/project_flow_name": current.project_flow_name,
+                 }
+             )
+
+         for name, value in annotations.items():
+             jobset.annotation(name, value)
+
+         (
+             jobset.annotation("metaflow/run_id", run_id)
+             .annotation("metaflow/step_name", step_name)
+             .annotation("metaflow/attempt", attempt)
+             .label("app.kubernetes.io/name", "metaflow-task")
+             .label("app.kubernetes.io/part-of", "metaflow")
+         )
+
+         ## ----------- control/worker specific values START here -----------
+         # We will now set the appropriate command for the control/worker job
+         _get_command = lambda index, _tskid: self._command(
+             flow_name=flow_name,
              run_id=run_id,
-             task_id=task_id,
              step_name=step_name,
-             namespace=namespace,
-             labels=self._get_labels(labels),
-             annotations=annotations,
-             num_parallel=num_parallel,
-             job_spec=job_spec,
-             port=port,
+             task_id=_tskid,
+             attempt=attempt,
+             code_package_url=code_package_url,
+             step_cmds=[
+                 step_cli.replace(
+                     METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE,
+                     "--ubf-context $UBF_CONTEXT --split-index %s --task-id %s"
+                     % (index, _tskid),
+                 )
+             ],
+         )
+         jobset.control.replicas(1)
+         jobset.worker.replicas(num_parallel - 1)
+
+         # We set the appropriate command for the control/worker job
+         # and also set the task-id/spit-index for the control/worker job
+         # appropirately.
+         jobset.control.command(_get_command("0", str(task_id)))
+         jobset.worker.command(
+             _get_command(
+                 "`expr $[MF_WORKER_REPLICA_INDEX] + 1`",
+                 "-".join(
+                     [
+                         str(task_id),
+                         "worker",
+                         "$MF_WORKER_REPLICA_INDEX",
+                     ]
+                 ),
+             )
          )
-         return js
+
+         jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
+         jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
+         # Every control job requires an environment variable of MF_CONTROL_INDEX
+         # set to 0 so that we can derive the MF_PARALLEL_NODE_INDEX correctly.
+         # Since only the control job has MF_CONTROL_INDE set to 0, all worker nodes
+         # will use MF_WORKER_REPLICA_INDEX
+         jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
+         ## ----------- control/worker specific values END here -----------
+
+         return jobset
 
      def create_job_object(
          self,
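For context, the JobSet branch above is taken when a step declares num_parallel, i.e. Metaflow's @parallel pattern combined with @kubernetes. A minimal flow sketch under that assumption; the decorator usage and the current.parallel fields follow Metaflow's general @parallel API rather than anything defined in this diff, so treat the details as illustrative.

    from metaflow import FlowSpec, step, kubernetes, parallel, current

    class DistributedFlow(FlowSpec):
        @step
        def start(self):
            # num_parallel fans this step out into one JobSet with a control
            # replica and (num_parallel - 1) worker replicas on Kubernetes.
            self.next(self.train, num_parallel=4)

        @kubernetes(cpu=2, memory=8000)
        @parallel
        @step
        def train(self):
            # Each replica learns its rank via the env vars set in create_jobset()
            # (MF_CONTROL_INDEX / MF_WORKER_REPLICA_INDEX, MF_WORLD_SIZE).
            print("node", current.parallel.node_index, "of", current.parallel.num_nodes)
            self.next(self.join)

        @step
        def join(self, inputs):
            self.next(self.end)

        @step
        def end(self):
            pass

    if __name__ == "__main__":
        DistributedFlow()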
@@ -241,7 +470,6 @@ class Kubernetes(object):
          shared_memory=None,
          port=None,
          name_pattern=None,
-         num_parallel=None,
      ):
          if env is None:
              env = {}
@@ -282,7 +510,6 @@ class Kubernetes(object):
                  persistent_volume_claims=persistent_volume_claims,
                  shared_memory=shared_memory,
                  port=port,
-                 num_parallel=num_parallel,
              )
              .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
              .environment_variable("METAFLOW_CODE_URL", code_package_url)