ob-metaflow 2.10.7.4__py2.py3-none-any.whl → 2.10.9.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow might be problematic.

Files changed (57)
  1. metaflow/cards.py +2 -0
  2. metaflow/decorators.py +1 -1
  3. metaflow/metaflow_config.py +4 -0
  4. metaflow/plugins/__init__.py +4 -0
  5. metaflow/plugins/airflow/airflow_cli.py +1 -1
  6. metaflow/plugins/argo/argo_workflows.py +5 -0
  7. metaflow/plugins/argo/argo_workflows_cli.py +1 -1
  8. metaflow/plugins/aws/aws_utils.py +1 -1
  9. metaflow/plugins/aws/batch/batch.py +4 -0
  10. metaflow/plugins/aws/batch/batch_cli.py +3 -0
  11. metaflow/plugins/aws/batch/batch_client.py +40 -11
  12. metaflow/plugins/aws/batch/batch_decorator.py +1 -0
  13. metaflow/plugins/aws/step_functions/step_functions.py +1 -0
  14. metaflow/plugins/aws/step_functions/step_functions_cli.py +1 -1
  15. metaflow/plugins/azure/azure_exceptions.py +1 -1
  16. metaflow/plugins/cards/card_cli.py +413 -28
  17. metaflow/plugins/cards/card_client.py +16 -7
  18. metaflow/plugins/cards/card_creator.py +228 -0
  19. metaflow/plugins/cards/card_datastore.py +124 -26
  20. metaflow/plugins/cards/card_decorator.py +40 -86
  21. metaflow/plugins/cards/card_modules/base.html +12 -0
  22. metaflow/plugins/cards/card_modules/basic.py +74 -8
  23. metaflow/plugins/cards/card_modules/bundle.css +1 -170
  24. metaflow/plugins/cards/card_modules/card.py +65 -0
  25. metaflow/plugins/cards/card_modules/components.py +446 -81
  26. metaflow/plugins/cards/card_modules/convert_to_native_type.py +9 -3
  27. metaflow/plugins/cards/card_modules/main.js +250 -21
  28. metaflow/plugins/cards/card_modules/test_cards.py +117 -0
  29. metaflow/plugins/cards/card_resolver.py +0 -2
  30. metaflow/plugins/cards/card_server.py +361 -0
  31. metaflow/plugins/cards/component_serializer.py +506 -42
  32. metaflow/plugins/cards/exception.py +20 -1
  33. metaflow/plugins/datastores/azure_storage.py +1 -2
  34. metaflow/plugins/datastores/gs_storage.py +1 -2
  35. metaflow/plugins/datastores/s3_storage.py +2 -1
  36. metaflow/plugins/datatools/s3/s3.py +24 -11
  37. metaflow/plugins/env_escape/client.py +2 -12
  38. metaflow/plugins/env_escape/client_modules.py +18 -14
  39. metaflow/plugins/env_escape/server.py +18 -11
  40. metaflow/plugins/env_escape/utils.py +12 -0
  41. metaflow/plugins/gcp/gs_exceptions.py +1 -1
  42. metaflow/plugins/gcp/gs_utils.py +1 -1
  43. metaflow/plugins/kubernetes/kubernetes.py +43 -6
  44. metaflow/plugins/kubernetes/kubernetes_cli.py +40 -1
  45. metaflow/plugins/kubernetes/kubernetes_decorator.py +73 -6
  46. metaflow/plugins/kubernetes/kubernetes_job.py +536 -161
  47. metaflow/plugins/pypi/conda_environment.py +5 -6
  48. metaflow/plugins/pypi/pip.py +2 -2
  49. metaflow/plugins/pypi/utils.py +15 -0
  50. metaflow/task.py +1 -0
  51. metaflow/version.py +1 -1
  52. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/METADATA +1 -1
  53. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/RECORD +57 -55
  54. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/LICENSE +0 -0
  55. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/WHEEL +0 -0
  56. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/entry_points.txt +0 -0
  57. {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/top_level.txt +0 -0
@@ -2,20 +2,18 @@ import json
  import math
  import random
  import time
-
- from metaflow.tracing import inject_tracing_vars
-
+ import os
+ import socket
+ import copy
 
  from metaflow.exception import MetaflowException
  from metaflow.metaflow_config import KUBERNETES_SECRETS
 
  CLIENT_REFRESH_INTERVAL_SECONDS = 300
 
-
  class KubernetesJobException(MetaflowException):
  headline = "Kubernetes job error"
 
-
  # Implements truncated exponential backoff from
  # https://cloud.google.com/storage/docs/retry-strategy#exponential-backoff
  def k8s_retry(deadline_seconds=60, max_backoff=32):
@@ -78,107 +76,260 @@ class KubernetesJob(object):
  tmpfs_size = self._kwargs["tmpfs_size"]
  tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
 
- self._job = client.V1Job(
- api_version="batch/v1",
- kind="Job",
- metadata=client.V1ObjectMeta(
- # Annotations are for humans
- annotations=self._kwargs.get("annotations", {}),
- # While labels are for Kubernetes
- labels=self._kwargs.get("labels", {}),
- generate_name=self._kwargs["generate_name"],
- namespace=self._kwargs["namespace"], # Defaults to `default`
- ),
- spec=client.V1JobSpec(
- # Retries are handled by Metaflow when it is responsible for
- # executing the flow. The responsibility is moved to Kubernetes
- # when Argo Workflows is responsible for the execution.
- backoff_limit=self._kwargs.get("retries", 0),
- completions=1, # A single non-indexed pod job
- ttl_seconds_after_finished=7
- * 60
- * 60 # Remove job after a week. TODO: Make this configurable
- * 24,
- template=client.V1PodTemplateSpec(
+ jobset_name = "js-%s" % self._kwargs["attrs"]["metaflow.task_id"].split('-')[-1]
+ main_job_name = "control"
+ main_job_index = 0
+ main_pod_index = 0
+ subdomain = jobset_name
+ master_port = int(self._kwargs['port']) if self._kwargs['port'] else None
+
+ passwordless_ssh = self._kwargs["attrs"]["requires_passwordless_ssh"]
+ if passwordless_ssh:
+ passwordless_ssh_service_name = subdomain
+ passwordless_ssh_service_selector = {
+ "passwordless-ssh-jobset": "true"
+ }
+ else:
+ passwordless_ssh_service_name = None
+ passwordless_ssh_service_selector = {}
+
+ fqdn_suffix = "%s.svc.cluster.local" % self._kwargs["namespace"]
+ jobset_main_addr = "%s-%s-%s-%s.%s.%s" % (
+ jobset_name,
+ main_job_name,
+ main_job_index,
+ main_pod_index,
+ subdomain,
+ fqdn_suffix,
+ )
+
+ def _install_jobset(
+ repo_url="https://github.com/kubernetes-sigs/jobset",
+ python_sdk_path="jobset/sdk/python",
+ ):
+
+ # TODO (Eddie): Remove this and suggest to user.
+
+ import subprocess
+ import tempfile
+ import shutil
+ import os
+
+ with open(os.devnull, "wb") as devnull:
+ cwd = os.getcwd()
+ tmp_dir = tempfile.mkdtemp()
+ os.chdir(tmp_dir)
+ subprocess.check_call(
+ ["git", "clone", repo_url], stdout=devnull, stderr=subprocess.STDOUT
+ )
+ tmp_python_sdk_path = os.path.join(tmp_dir, python_sdk_path)
+ os.chdir(tmp_python_sdk_path)
+ subprocess.check_call(
+ ["pip", "install", "."], stdout=devnull, stderr=subprocess.STDOUT
+ )
+ os.chdir(cwd)
+ shutil.rmtree(tmp_dir)
+
+ def _get_passwordless_ssh_service():
+
+ return client.V1Service(
+ api_version="v1",
+ kind="Service",
+ metadata=client.V1ObjectMeta(
+ name=passwordless_ssh_service_name,
+ namespace=self._kwargs["namespace"]
+ ),
+ spec=client.V1ServiceSpec(
+ cluster_ip="None",
+ internal_traffic_policy="Cluster",
+ ip_families=["IPv4"],
+ ip_family_policy="SingleStack",
+ selector=passwordless_ssh_service_selector,
+ session_affinity="None",
+ type="ClusterIP",
+ ports=[
+ client.V1ServicePort(
+ name="control",
+ port=22,
+ protocol="TCP",
+ target_port=22
+ )
+ ]
+ )
+ )
+
+ def _get_replicated_job(job_name, parallelism, command):
+ return jobset.models.jobset_v1alpha2_replicated_job.JobsetV1alpha2ReplicatedJob(
+ name=job_name,
+ template=client.V1JobTemplateSpec(
  metadata=client.V1ObjectMeta(
  annotations=self._kwargs.get("annotations", {}),
  labels=self._kwargs.get("labels", {}),
  namespace=self._kwargs["namespace"],
  ),
- spec=client.V1PodSpec(
- # Timeout is set on the pod and not the job (important!)
- active_deadline_seconds=self._kwargs["timeout_in_seconds"],
- # TODO (savin): Enable affinities for GPU scheduling.
- # affinity=?,
- containers=[
- client.V1Container(
- command=self._kwargs["command"],
- env=[
- client.V1EnvVar(name=k, value=str(v))
- for k, v in self._kwargs.get(
- "environment_variables", {}
- ).items()
- ]
- # And some downward API magic. Add (key, value)
- # pairs below to make pod metadata available
- # within Kubernetes container.
- + [
- client.V1EnvVar(
- name=k,
- value_from=client.V1EnvVarSource(
- field_ref=client.V1ObjectFieldSelector(
- field_path=str(v)
+ spec=client.V1JobSpec(
+ parallelism=parallelism, # how many jobs can run at once
+ completions=parallelism, # how many Pods the JobSet creates in total
+ backoff_limit=0,
+ ttl_seconds_after_finished=7
+ * 60
+ * 60
+ * 24,
+ template=client.V1PodTemplateSpec(
+ metadata=client.V1ObjectMeta(
+ annotations=self._kwargs.get("annotations", {}),
+ labels={
+ **self._kwargs.get("labels", {}),
+ **passwordless_ssh_service_selector, # TODO: necessary?
+ # TODO: cluster-name, app.kubernetes.io/name necessary?
+ },
+ namespace=self._kwargs["namespace"],
+ ),
+ spec=client.V1PodSpec(
+ active_deadline_seconds=self._kwargs[
+ "timeout_in_seconds"
+ ],
+ containers=[
+ client.V1Container(
+ command=command,
+ ports=[client.V1ContainerPort(container_port=master_port)] if master_port and job_name=="control" else [],
+ env=[
+ client.V1EnvVar(name=k, value=str(v))
+ for k, v in self._kwargs.get(
+ "environment_variables", {}
+ ).items()
+ ]
+ + [
+ client.V1EnvVar(
+ name=k,
+ value_from=client.V1EnvVarSource(
+ field_ref=client.V1ObjectFieldSelector(
+ field_path=str(v)
+ )
+ ),
+ )
+ for k, v in {
+ "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+ "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+ "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+ "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+ "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+ }.items()
+ ]
+ # Mimicking the AWS Batch Multinode env vars.
+ + [
+ client.V1EnvVar(
+ name="MASTER_ADDR",
+ value=jobset_main_addr,
+ ),
+ client.V1EnvVar(
+ name="MASTER_PORT",
+ value=str(master_port),
+ ),
+ client.V1EnvVar(
+ name="RANK",
+ value_from=client.V1EnvVarSource(
+ field_ref=client.V1ObjectFieldSelector(
+ field_path="metadata.annotations['batch.kubernetes.io/job-completion-index']"
+ )
+ ),
+ ),
+ client.V1EnvVar(
+ name="WORLD_SIZE",
+ value=str(self._kwargs["num_parallel"]),
+ ),
+ client.V1EnvVar(
+ name="PYTHONUNBUFFERED",
+ value="0",
+ ),
+ ],
+ env_from=[
+ client.V1EnvFromSource(
+ secret_ref=client.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
  )
+ for k in list(
+ self._kwargs.get("secrets", [])
+ )
+ + KUBERNETES_SECRETS.split(",")
+ if k
+ ],
+ image=self._kwargs["image"],
+ image_pull_policy=self._kwargs[
+ "image_pull_policy"
+ ],
+ name=self._kwargs["step_name"].replace(
+ "_", "-"
  ),
- )
- for k, v in {
- "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
- "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
- "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
- "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
- "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
- }.items()
- ]
- + [
- client.V1EnvVar(name=k, value=str(v))
- for k, v in inject_tracing_vars({}).items()
- ],
- env_from=[
- client.V1EnvFromSource(
- secret_ref=client.V1SecretEnvSource(
- name=str(k),
- # optional=True
+ resources=client.V1ResourceRequirements(
+ requests={
+ "cpu": str(self._kwargs["cpu"]),
+ "memory": "%sM"
+ % str(self._kwargs["memory"]),
+ "ephemeral-storage": "%sM"
+ % str(self._kwargs["disk"]),
+ },
+ limits={
+ "%s.com/gpu".lower()
+ % self._kwargs["gpu_vendor"]: str(
+ self._kwargs["gpu"]
+ )
+ for k in [0]
+ # Don't set GPU limits if gpu isn't specified.
+ if self._kwargs["gpu"] is not None
+ },
+ ),
+ volume_mounts=(
+ [
+ client.V1VolumeMount(
+ mount_path=self._kwargs.get(
+ "tmpfs_path"
+ ),
+ name="tmpfs-ephemeral-volume",
+ )
+ ]
+ if tmpfs_enabled
+ else []
  )
+ + (
+ [
+ client.V1VolumeMount(
+ mount_path=path, name=claim
+ )
+ for claim, path in self._kwargs[
+ "persistent_volume_claims"
+ ].items()
+ ]
+ if self._kwargs["persistent_volume_claims"]
+ is not None
+ else []
+ ),
  )
- for k in list(self._kwargs.get("secrets", []))
- + KUBERNETES_SECRETS.split(",")
- if k
  ],
- image=self._kwargs["image"],
- image_pull_policy=self._kwargs["image_pull_policy"],
- name=self._kwargs["step_name"].replace("_", "-"),
- resources=client.V1ResourceRequirements(
- requests={
- "cpu": str(self._kwargs["cpu"]),
- "memory": "%sM" % str(self._kwargs["memory"]),
- "ephemeral-storage": "%sM"
- % str(self._kwargs["disk"]),
- },
- limits={
- "%s.com/gpu".lower()
- % self._kwargs["gpu_vendor"]: str(
- self._kwargs["gpu"]
- )
- for k in [0]
- # Don't set GPU limits if gpu isn't specified.
- if self._kwargs["gpu"] is not None
- },
- ),
- volume_mounts=(
+ node_selector=self._kwargs.get("node_selector"),
+ restart_policy="Never",
+
+ set_hostname_as_fqdn=True, # configure pod hostname as pod's FQDN
+ share_process_namespace=False, # default
+ subdomain=subdomain, # FQDN = <hostname>.<subdomain>.<pod namespace>.svc.<cluster domain>
+
+ service_account_name=self._kwargs["service_account"],
+ termination_grace_period_seconds=0,
+ tolerations=[
+ client.V1Toleration(**toleration)
+ for toleration in self._kwargs.get("tolerations")
+ or []
+ ],
+ volumes=(
  [
- client.V1VolumeMount(
- mount_path=self._kwargs.get("tmpfs_path"),
+ client.V1Volume(
  name="tmpfs-ephemeral-volume",
+ empty_dir=client.V1EmptyDirVolumeSource(
+ medium="Memory",
+ size_limit="{}Mi".format(tmpfs_size),
+ ),
  )
  ]
  if tmpfs_enabled
@@ -186,72 +337,264 @@ class KubernetesJob(object):
  )
  + (
  [
- client.V1VolumeMount(
- mount_path=path, name=claim
+ client.V1Volume(
+ name=claim,
+ persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
+ claim_name=claim
+ ),
  )
- for claim, path in self._kwargs[
+ for claim in self._kwargs[
  "persistent_volume_claims"
- ].items()
+ ].keys()
  ]
  if self._kwargs["persistent_volume_claims"]
  is not None
  else []
  ),
- )
- ],
- node_selector=self._kwargs.get("node_selector"),
- # TODO (savin): Support image_pull_secrets
- # image_pull_secrets=?,
- # TODO (savin): Support preemption policies
- # preemption_policy=?,
- #
- # A Container in a Pod may fail for a number of
- # reasons, such as because the process in it exited
- # with a non-zero exit code, or the Container was
- # killed due to OOM etc. If this happens, fail the pod
- # and let Metaflow handle the retries.
- restart_policy="Never",
- service_account_name=self._kwargs["service_account"],
- # Terminate the container immediately on SIGTERM
- termination_grace_period_seconds=0,
- tolerations=[
- client.V1Toleration(**toleration)
- for toleration in self._kwargs.get("tolerations") or []
- ],
- volumes=(
- [
- client.V1Volume(
- name="tmpfs-ephemeral-volume",
- empty_dir=client.V1EmptyDirVolumeSource(
- medium="Memory",
- # Add default unit as ours differs from Kubernetes default.
- size_limit="{}Mi".format(tmpfs_size),
+ ),
+ ),
+ ),
+ ),
+ )
+
+ if "num_parallel" in self._kwargs and self._kwargs["num_parallel"] >= 1:
+
+ try:
+ import jobset
+ except ImportError:
+ _install_jobset()
+ import jobset
+
+ main_commands = copy.copy(self._kwargs["command"])
+ main_commands[-1] = main_commands[-1].replace(
+ "[multinode-args]", "--split-index 0"
+ )
+
+ task_id = self._kwargs["attrs"]["metaflow.task_id"]
+ secondary_commands = copy.copy(self._kwargs["command"])
+ # RANK needs +1 because control node is not in the worker index group, yet we want global nodes.
+ # Technically, control and worker could be same replicated job type, but cleaner to separate for future use cases.
+ secondary_commands[-1] = secondary_commands[-1].replace(
+ "[multinode-args]", "--split-index `expr $RANK + 1`"
+ )
+ secondary_commands[-1] = secondary_commands[-1].replace(
+ "ubf_control", "ubf_task"
+ )
+ secondary_commands[-1] = secondary_commands[-1].replace(
+ task_id,
+ task_id.replace("control-", "") + "-node-`expr $RANK + 1`",
+ )
+
+ if passwordless_ssh:
+ if not os.path.exists("/usr/sbin/sshd"):
+ raise KubernetesJobException(
+ "This @parallel decorator requires sshd to be installed in the container image."
+ "Please install OpenSSH."
+ )
+
+ # run sshd in background
+ main_commands[-1] = "/usr/sbin/sshd -D & %s" % main_commands[-1]
+ secondary_commands[-1] = "/usr/sbin/sshd -D & %s" % secondary_commands[-1]
+
+ self._jobset = jobset.models.jobset_v1alpha2_job_set.JobsetV1alpha2JobSet(
+ api_version="jobset.x-k8s.io/v1alpha2",
+ kind="JobSet",
+ metadata=client.V1ObjectMeta(
+ annotations=self._kwargs.get("annotations", {}),
+ labels=self._kwargs.get("labels", {}),
+ name=jobset_name,
+ namespace=self._kwargs["namespace"],
+ ),
+ spec=jobset.models.jobset_v1alpha2_job_set_spec.JobsetV1alpha2JobSetSpec(
+ network=jobset.models.jobset_v1alpha2_network.JobsetV1alpha2Network(
+ enable_dns_hostnames=True if not self._kwargs['attrs']['requires_passwordless_ssh'] else False,
+ subdomain=subdomain
+ ),
+ replicated_jobs=[
+ _get_replicated_job("control", 1, main_commands),
+ _get_replicated_job(
+ "worker",
+ self._kwargs["num_parallel"] - 1,
+ secondary_commands,
+ ),
+ ],
+ ),
+ )
+ self._passwordless_ssh_service = _get_passwordless_ssh_service()
+ else:
+ self._job = client.V1Job(
+ api_version="batch/v1",
+ kind="Job",
+ metadata=client.V1ObjectMeta(
+ # Annotations are for humans
+ annotations=self._kwargs.get("annotations", {}),
+ # While labels are for Kubernetes
+ labels=self._kwargs.get("labels", {}),
+ generate_name=self._kwargs["generate_name"],
+ namespace=self._kwargs["namespace"], # Defaults to `default`
+ ),
+ spec=client.V1JobSpec(
+ # Retries are handled by Metaflow when it is responsible for
+ # executing the flow. The responsibility is moved to Kubernetes
+ # when Argo Workflows is responsible for the execution.
+ backoff_limit=self._kwargs.get("retries", 0),
+ completions=1, # A single non-indexed pod job
+ ttl_seconds_after_finished=7
+ * 60
+ * 60 # Remove job after a week. TODO: Make this configurable
+ * 24,
+ template=client.V1PodTemplateSpec(
+ metadata=client.V1ObjectMeta(
+ annotations=self._kwargs.get("annotations", {}),
+ labels=self._kwargs.get("labels", {}),
+ namespace=self._kwargs["namespace"],
+ ),
+ spec=client.V1PodSpec(
+ # Timeout is set on the pod and not the job (important!)
+ active_deadline_seconds=self._kwargs["timeout_in_seconds"],
+ # TODO (savin): Enable affinities for GPU scheduling.
+ # affinity=?,
+ containers=[
+ client.V1Container(
+ command=self._kwargs["command"],
+ env=[
+ client.V1EnvVar(name=k, value=str(v))
+ for k, v in self._kwargs.get(
+ "environment_variables", {}
+ ).items()
+ ]
+ # And some downward API magic. Add (key, value)
+ # pairs below to make pod metadata available
+ # within Kubernetes container.
+ + [
+ client.V1EnvVar(
+ name=k,
+ value_from=client.V1EnvVarSource(
+ field_ref=client.V1ObjectFieldSelector(
+ field_path=str(v)
+ )
+ ),
+ )
+ for k, v in {
+ "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+ "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+ "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+ "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+ "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+ }.items()
+ ],
+ env_from=[
+ client.V1EnvFromSource(
+ secret_ref=client.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
+ )
+ for k in list(self._kwargs.get("secrets", []))
+ + KUBERNETES_SECRETS.split(",")
+ if k
+ ],
+ image=self._kwargs["image"],
+ image_pull_policy=self._kwargs["image_pull_policy"],
+ name=self._kwargs["step_name"].replace("_", "-"),
+ resources=client.V1ResourceRequirements(
+ requests={
+ "cpu": str(self._kwargs["cpu"]),
+ "memory": "%sM"
+ % str(self._kwargs["memory"]),
+ "ephemeral-storage": "%sM"
+ % str(self._kwargs["disk"]),
+ },
+ limits={
+ "%s.com/gpu".lower()
+ % self._kwargs["gpu_vendor"]: str(
+ self._kwargs["gpu"]
+ )
+ for k in [0]
+ # Don't set GPU limits if gpu isn't specified.
+ if self._kwargs["gpu"] is not None
+ },
  ),
- )
- ]
- if tmpfs_enabled
- else []
- )
- + (
- [
- client.V1Volume(
- name=claim,
- persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
- claim_name=claim
+ volume_mounts=(
+ [
+ client.V1VolumeMount(
+ mount_path=self._kwargs.get(
+ "tmpfs_path"
+ ),
+ name="tmpfs-ephemeral-volume",
+ )
+ ]
+ if tmpfs_enabled
+ else []
+ )
+ + (
+ [
+ client.V1VolumeMount(
+ mount_path=path, name=claim
+ )
+ for claim, path in self._kwargs[
+ "persistent_volume_claims"
+ ].items()
+ ]
+ if self._kwargs["persistent_volume_claims"]
+ is not None
+ else []
  ),
  )
- for claim in self._kwargs[
- "persistent_volume_claims"
- ].keys()
- ]
- if self._kwargs["persistent_volume_claims"] is not None
- else []
+ ],
+ node_selector=self._kwargs.get("node_selector"),
+ # TODO (savin): Support image_pull_secrets
+ # image_pull_secrets=?,
+ # TODO (savin): Support preemption policies
+ # preemption_policy=?,
+ #
+ # A Container in a Pod may fail for a number of
+ # reasons, such as because the process in it exited
+ # with a non-zero exit code, or the Container was
+ # killed due to OOM etc. If this happens, fail the pod
+ # and let Metaflow handle the retries.
+ restart_policy="Never",
+ service_account_name=self._kwargs["service_account"],
+ # Terminate the container immediately on SIGTERM
+ termination_grace_period_seconds=0,
+ tolerations=[
+ client.V1Toleration(**toleration)
+ for toleration in self._kwargs.get("tolerations") or []
+ ],
+ volumes=(
+ [
+ client.V1Volume(
+ name="tmpfs-ephemeral-volume",
+ empty_dir=client.V1EmptyDirVolumeSource(
+ medium="Memory",
+ # Add default unit as ours differs from Kubernetes default.
+ size_limit="{}Mi".format(tmpfs_size),
+ ),
+ )
+ ]
+ if tmpfs_enabled
+ else []
+ )
+ + (
+ [
+ client.V1Volume(
+ name=claim,
+ persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
+ claim_name=claim
+ ),
+ )
+ for claim in self._kwargs[
+ "persistent_volume_claims"
+ ].keys()
+ ]
+ if self._kwargs["persistent_volume_claims"] is not None
+ else []
+ ),
+ # TODO (savin): Set termination_message_policy
  ),
- # TODO (savin): Set termination_message_policy
  ),
  ),
- ),
- )
+ )
  return self
 
  def execute(self):
@@ -262,19 +605,53 @@ class KubernetesJob(object):
  # achieve the guarantees that we are seeking.
  # https://github.com/kubernetes/enhancements/issues/1040
  # Hopefully, we will be able to get creative with kube-batch
- response = (
- client.BatchV1Api()
- .create_namespaced_job(
- body=self._job, namespace=self._kwargs["namespace"]
+
+ if "num_parallel" in self._kwargs and self._kwargs["num_parallel"] >= 1:
+ # TODO (Eddie): this is kinda gross. fix it.
+ if self._kwargs["attrs"]["requires_passwordless_ssh"]:
+ api_instance = client.CoreV1Api()
+ api_response = api_instance.create_namespaced_service(namespace=self._kwargs['namespace'], body=self._passwordless_ssh_service)
+
+ with client.ApiClient() as api_client:
+ api_instance = client.CustomObjectsApi(api_client)
+
+ response = api_instance.create_namespaced_custom_object(
+ body=self._jobset,
+ group="jobset.x-k8s.io",
+ version="v1alpha2",
+ namespace=self._kwargs["namespace"],
+ plural="jobsets",
  )
- .to_dict()
- )
- return RunningJob(
- client=self._client,
- name=response["metadata"]["name"],
- uid=response["metadata"]["uid"],
- namespace=response["metadata"]["namespace"],
- )
+
+ # HACK: Give K8s some time to actually create the job
+ time.sleep(10)
+
+ # TODO (Eddie): Remove hack and make RunningJobSet.
+ # There are many jobs running that should be monitored.
+ job_name = "%s-control-0" % response["metadata"]["name"]
+ fake_id = 123
+ return RunningJob(
+ client=self._client,
+ name=job_name,
+ uid=fake_id,
+ namespace=response["metadata"]["namespace"],
+ )
+
+ else:
+ response = (
+ client.BatchV1Api()
+ .create_namespaced_job(
+ body=self._job, namespace=self._kwargs["namespace"]
+ )
+ .to_dict()
+ )
+ return RunningJob(
+ client=self._client,
+ name=response["metadata"]["name"],
+ uid=response["metadata"]["uid"],
+ namespace=response["metadata"]["namespace"],
+ )
+
  except client.rest.ApiException as e:
  raise KubernetesJobException(
  "Unable to launch Kubernetes job.\n %s"
@@ -330,7 +707,6 @@ class KubernetesJob(object):
 
 
  class RunningJob(object):
-
  # State Machine implementation for the lifecycle behavior documented in
  # https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/
  #
@@ -450,7 +826,6 @@ class RunningJob(object):
  client = self._client.get()
  if not self.is_done:
  if self.is_running:
-
  # Case 1.
  from kubernetes.stream import stream