ob-metaflow 2.11.13.1__py2.py3-none-any.whl → 2.19.7.1rc0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (289)
  1. metaflow/R.py +10 -7
  2. metaflow/__init__.py +40 -25
  3. metaflow/_vendor/imghdr/__init__.py +186 -0
  4. metaflow/_vendor/importlib_metadata/__init__.py +1063 -0
  5. metaflow/_vendor/importlib_metadata/_adapters.py +68 -0
  6. metaflow/_vendor/importlib_metadata/_collections.py +30 -0
  7. metaflow/_vendor/importlib_metadata/_compat.py +71 -0
  8. metaflow/_vendor/importlib_metadata/_functools.py +104 -0
  9. metaflow/_vendor/importlib_metadata/_itertools.py +73 -0
  10. metaflow/_vendor/importlib_metadata/_meta.py +48 -0
  11. metaflow/_vendor/importlib_metadata/_text.py +99 -0
  12. metaflow/_vendor/importlib_metadata/py.typed +0 -0
  13. metaflow/_vendor/typeguard/__init__.py +48 -0
  14. metaflow/_vendor/typeguard/_checkers.py +1070 -0
  15. metaflow/_vendor/typeguard/_config.py +108 -0
  16. metaflow/_vendor/typeguard/_decorators.py +233 -0
  17. metaflow/_vendor/typeguard/_exceptions.py +42 -0
  18. metaflow/_vendor/typeguard/_functions.py +308 -0
  19. metaflow/_vendor/typeguard/_importhook.py +213 -0
  20. metaflow/_vendor/typeguard/_memo.py +48 -0
  21. metaflow/_vendor/typeguard/_pytest_plugin.py +127 -0
  22. metaflow/_vendor/typeguard/_suppression.py +86 -0
  23. metaflow/_vendor/typeguard/_transformer.py +1229 -0
  24. metaflow/_vendor/typeguard/_union_transformer.py +55 -0
  25. metaflow/_vendor/typeguard/_utils.py +173 -0
  26. metaflow/_vendor/typeguard/py.typed +0 -0
  27. metaflow/_vendor/typing_extensions.py +3641 -0
  28. metaflow/_vendor/v3_7/importlib_metadata/__init__.py +1063 -0
  29. metaflow/_vendor/v3_7/importlib_metadata/_adapters.py +68 -0
  30. metaflow/_vendor/v3_7/importlib_metadata/_collections.py +30 -0
  31. metaflow/_vendor/v3_7/importlib_metadata/_compat.py +71 -0
  32. metaflow/_vendor/v3_7/importlib_metadata/_functools.py +104 -0
  33. metaflow/_vendor/v3_7/importlib_metadata/_itertools.py +73 -0
  34. metaflow/_vendor/v3_7/importlib_metadata/_meta.py +48 -0
  35. metaflow/_vendor/v3_7/importlib_metadata/_text.py +99 -0
  36. metaflow/_vendor/v3_7/importlib_metadata/py.typed +0 -0
  37. metaflow/_vendor/v3_7/typeguard/__init__.py +48 -0
  38. metaflow/_vendor/v3_7/typeguard/_checkers.py +906 -0
  39. metaflow/_vendor/v3_7/typeguard/_config.py +108 -0
  40. metaflow/_vendor/v3_7/typeguard/_decorators.py +237 -0
  41. metaflow/_vendor/v3_7/typeguard/_exceptions.py +42 -0
  42. metaflow/_vendor/v3_7/typeguard/_functions.py +310 -0
  43. metaflow/_vendor/v3_7/typeguard/_importhook.py +213 -0
  44. metaflow/_vendor/v3_7/typeguard/_memo.py +48 -0
  45. metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py +100 -0
  46. metaflow/_vendor/v3_7/typeguard/_suppression.py +88 -0
  47. metaflow/_vendor/v3_7/typeguard/_transformer.py +1207 -0
  48. metaflow/_vendor/v3_7/typeguard/_union_transformer.py +54 -0
  49. metaflow/_vendor/v3_7/typeguard/_utils.py +169 -0
  50. metaflow/_vendor/v3_7/typeguard/py.typed +0 -0
  51. metaflow/_vendor/v3_7/typing_extensions.py +3072 -0
  52. metaflow/_vendor/yaml/__init__.py +427 -0
  53. metaflow/_vendor/yaml/composer.py +139 -0
  54. metaflow/_vendor/yaml/constructor.py +748 -0
  55. metaflow/_vendor/yaml/cyaml.py +101 -0
  56. metaflow/_vendor/yaml/dumper.py +62 -0
  57. metaflow/_vendor/yaml/emitter.py +1137 -0
  58. metaflow/_vendor/yaml/error.py +75 -0
  59. metaflow/_vendor/yaml/events.py +86 -0
  60. metaflow/_vendor/yaml/loader.py +63 -0
  61. metaflow/_vendor/yaml/nodes.py +49 -0
  62. metaflow/_vendor/yaml/parser.py +589 -0
  63. metaflow/_vendor/yaml/reader.py +185 -0
  64. metaflow/_vendor/yaml/representer.py +389 -0
  65. metaflow/_vendor/yaml/resolver.py +227 -0
  66. metaflow/_vendor/yaml/scanner.py +1435 -0
  67. metaflow/_vendor/yaml/serializer.py +111 -0
  68. metaflow/_vendor/yaml/tokens.py +104 -0
  69. metaflow/cards.py +5 -0
  70. metaflow/cli.py +331 -785
  71. metaflow/cli_args.py +17 -0
  72. metaflow/cli_components/__init__.py +0 -0
  73. metaflow/cli_components/dump_cmd.py +96 -0
  74. metaflow/cli_components/init_cmd.py +52 -0
  75. metaflow/cli_components/run_cmds.py +546 -0
  76. metaflow/cli_components/step_cmd.py +334 -0
  77. metaflow/cli_components/utils.py +140 -0
  78. metaflow/client/__init__.py +1 -0
  79. metaflow/client/core.py +467 -73
  80. metaflow/client/filecache.py +75 -35
  81. metaflow/clone_util.py +7 -1
  82. metaflow/cmd/code/__init__.py +231 -0
  83. metaflow/cmd/develop/stub_generator.py +756 -288
  84. metaflow/cmd/develop/stubs.py +12 -28
  85. metaflow/cmd/main_cli.py +6 -4
  86. metaflow/cmd/make_wrapper.py +78 -0
  87. metaflow/datastore/__init__.py +1 -0
  88. metaflow/datastore/content_addressed_store.py +41 -10
  89. metaflow/datastore/datastore_set.py +11 -2
  90. metaflow/datastore/flow_datastore.py +156 -10
  91. metaflow/datastore/spin_datastore.py +91 -0
  92. metaflow/datastore/task_datastore.py +154 -39
  93. metaflow/debug.py +5 -0
  94. metaflow/decorators.py +404 -78
  95. metaflow/exception.py +8 -2
  96. metaflow/extension_support/__init__.py +527 -376
  97. metaflow/extension_support/_empty_file.py +2 -2
  98. metaflow/extension_support/plugins.py +49 -31
  99. metaflow/flowspec.py +482 -33
  100. metaflow/graph.py +210 -42
  101. metaflow/includefile.py +84 -40
  102. metaflow/lint.py +141 -22
  103. metaflow/meta_files.py +13 -0
  104. metaflow/{metadata → metadata_provider}/heartbeat.py +24 -8
  105. metaflow/{metadata → metadata_provider}/metadata.py +86 -1
  106. metaflow/metaflow_config.py +175 -28
  107. metaflow/metaflow_config_funcs.py +51 -3
  108. metaflow/metaflow_current.py +4 -10
  109. metaflow/metaflow_environment.py +139 -53
  110. metaflow/metaflow_git.py +115 -0
  111. metaflow/metaflow_profile.py +18 -0
  112. metaflow/metaflow_version.py +150 -66
  113. metaflow/mflog/__init__.py +4 -3
  114. metaflow/mflog/save_logs.py +2 -2
  115. metaflow/multicore_utils.py +31 -14
  116. metaflow/package/__init__.py +673 -0
  117. metaflow/packaging_sys/__init__.py +880 -0
  118. metaflow/packaging_sys/backend.py +128 -0
  119. metaflow/packaging_sys/distribution_support.py +153 -0
  120. metaflow/packaging_sys/tar_backend.py +99 -0
  121. metaflow/packaging_sys/utils.py +54 -0
  122. metaflow/packaging_sys/v1.py +527 -0
  123. metaflow/parameters.py +149 -28
  124. metaflow/plugins/__init__.py +74 -5
  125. metaflow/plugins/airflow/airflow.py +40 -25
  126. metaflow/plugins/airflow/airflow_cli.py +22 -5
  127. metaflow/plugins/airflow/airflow_decorator.py +1 -1
  128. metaflow/plugins/airflow/airflow_utils.py +5 -3
  129. metaflow/plugins/airflow/sensors/base_sensor.py +4 -4
  130. metaflow/plugins/airflow/sensors/external_task_sensor.py +2 -2
  131. metaflow/plugins/airflow/sensors/s3_sensor.py +2 -2
  132. metaflow/plugins/argo/argo_client.py +78 -33
  133. metaflow/plugins/argo/argo_events.py +6 -6
  134. metaflow/plugins/argo/argo_workflows.py +2410 -527
  135. metaflow/plugins/argo/argo_workflows_cli.py +571 -121
  136. metaflow/plugins/argo/argo_workflows_decorator.py +43 -12
  137. metaflow/plugins/argo/argo_workflows_deployer.py +106 -0
  138. metaflow/plugins/argo/argo_workflows_deployer_objects.py +453 -0
  139. metaflow/plugins/argo/capture_error.py +73 -0
  140. metaflow/plugins/argo/conditional_input_paths.py +35 -0
  141. metaflow/plugins/argo/exit_hooks.py +209 -0
  142. metaflow/plugins/argo/jobset_input_paths.py +15 -0
  143. metaflow/plugins/argo/param_val.py +19 -0
  144. metaflow/plugins/aws/aws_client.py +10 -3
  145. metaflow/plugins/aws/aws_utils.py +55 -2
  146. metaflow/plugins/aws/batch/batch.py +72 -5
  147. metaflow/plugins/aws/batch/batch_cli.py +33 -10
  148. metaflow/plugins/aws/batch/batch_client.py +4 -3
  149. metaflow/plugins/aws/batch/batch_decorator.py +102 -35
  150. metaflow/plugins/aws/secrets_manager/aws_secrets_manager_secrets_provider.py +13 -10
  151. metaflow/plugins/aws/step_functions/dynamo_db_client.py +0 -3
  152. metaflow/plugins/aws/step_functions/production_token.py +1 -1
  153. metaflow/plugins/aws/step_functions/step_functions.py +65 -8
  154. metaflow/plugins/aws/step_functions/step_functions_cli.py +101 -7
  155. metaflow/plugins/aws/step_functions/step_functions_decorator.py +1 -2
  156. metaflow/plugins/aws/step_functions/step_functions_deployer.py +97 -0
  157. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +264 -0
  158. metaflow/plugins/azure/azure_exceptions.py +1 -1
  159. metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
  160. metaflow/plugins/azure/azure_tail.py +1 -1
  161. metaflow/plugins/azure/includefile_support.py +2 -0
  162. metaflow/plugins/cards/card_cli.py +66 -30
  163. metaflow/plugins/cards/card_creator.py +25 -1
  164. metaflow/plugins/cards/card_datastore.py +21 -49
  165. metaflow/plugins/cards/card_decorator.py +132 -8
  166. metaflow/plugins/cards/card_modules/basic.py +112 -17
  167. metaflow/plugins/cards/card_modules/bundle.css +1 -1
  168. metaflow/plugins/cards/card_modules/card.py +16 -1
  169. metaflow/plugins/cards/card_modules/chevron/renderer.py +1 -1
  170. metaflow/plugins/cards/card_modules/components.py +665 -28
  171. metaflow/plugins/cards/card_modules/convert_to_native_type.py +36 -7
  172. metaflow/plugins/cards/card_modules/json_viewer.py +232 -0
  173. metaflow/plugins/cards/card_modules/main.css +1 -0
  174. metaflow/plugins/cards/card_modules/main.js +68 -49
  175. metaflow/plugins/cards/card_modules/renderer_tools.py +1 -0
  176. metaflow/plugins/cards/card_modules/test_cards.py +26 -12
  177. metaflow/plugins/cards/card_server.py +39 -14
  178. metaflow/plugins/cards/component_serializer.py +2 -9
  179. metaflow/plugins/cards/metadata.py +22 -0
  180. metaflow/plugins/catch_decorator.py +9 -0
  181. metaflow/plugins/datastores/azure_storage.py +10 -1
  182. metaflow/plugins/datastores/gs_storage.py +6 -2
  183. metaflow/plugins/datastores/local_storage.py +12 -6
  184. metaflow/plugins/datastores/spin_storage.py +12 -0
  185. metaflow/plugins/datatools/local.py +2 -0
  186. metaflow/plugins/datatools/s3/s3.py +126 -75
  187. metaflow/plugins/datatools/s3/s3op.py +254 -121
  188. metaflow/plugins/env_escape/__init__.py +3 -3
  189. metaflow/plugins/env_escape/client_modules.py +102 -72
  190. metaflow/plugins/env_escape/server.py +7 -0
  191. metaflow/plugins/env_escape/stub.py +24 -5
  192. metaflow/plugins/events_decorator.py +343 -185
  193. metaflow/plugins/exit_hook/__init__.py +0 -0
  194. metaflow/plugins/exit_hook/exit_hook_decorator.py +46 -0
  195. metaflow/plugins/exit_hook/exit_hook_script.py +52 -0
  196. metaflow/plugins/gcp/__init__.py +1 -1
  197. metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +11 -6
  198. metaflow/plugins/gcp/gs_tail.py +10 -6
  199. metaflow/plugins/gcp/includefile_support.py +3 -0
  200. metaflow/plugins/kubernetes/kube_utils.py +108 -0
  201. metaflow/plugins/kubernetes/kubernetes.py +411 -130
  202. metaflow/plugins/kubernetes/kubernetes_cli.py +168 -36
  203. metaflow/plugins/kubernetes/kubernetes_client.py +104 -2
  204. metaflow/plugins/kubernetes/kubernetes_decorator.py +246 -88
  205. metaflow/plugins/kubernetes/kubernetes_job.py +253 -581
  206. metaflow/plugins/kubernetes/kubernetes_jobsets.py +1071 -0
  207. metaflow/plugins/kubernetes/spot_metadata_cli.py +69 -0
  208. metaflow/plugins/kubernetes/spot_monitor_sidecar.py +109 -0
  209. metaflow/plugins/logs_cli.py +359 -0
  210. metaflow/plugins/{metadata → metadata_providers}/local.py +144 -84
  211. metaflow/plugins/{metadata → metadata_providers}/service.py +103 -26
  212. metaflow/plugins/metadata_providers/spin.py +16 -0
  213. metaflow/plugins/package_cli.py +36 -24
  214. metaflow/plugins/parallel_decorator.py +128 -11
  215. metaflow/plugins/parsers.py +16 -0
  216. metaflow/plugins/project_decorator.py +51 -5
  217. metaflow/plugins/pypi/bootstrap.py +357 -105
  218. metaflow/plugins/pypi/conda_decorator.py +82 -81
  219. metaflow/plugins/pypi/conda_environment.py +187 -52
  220. metaflow/plugins/pypi/micromamba.py +157 -47
  221. metaflow/plugins/pypi/parsers.py +268 -0
  222. metaflow/plugins/pypi/pip.py +88 -13
  223. metaflow/plugins/pypi/pypi_decorator.py +37 -1
  224. metaflow/plugins/pypi/utils.py +48 -2
  225. metaflow/plugins/resources_decorator.py +2 -2
  226. metaflow/plugins/secrets/__init__.py +3 -0
  227. metaflow/plugins/secrets/secrets_decorator.py +26 -181
  228. metaflow/plugins/secrets/secrets_func.py +49 -0
  229. metaflow/plugins/secrets/secrets_spec.py +101 -0
  230. metaflow/plugins/secrets/utils.py +74 -0
  231. metaflow/plugins/tag_cli.py +4 -7
  232. metaflow/plugins/test_unbounded_foreach_decorator.py +41 -6
  233. metaflow/plugins/timeout_decorator.py +3 -3
  234. metaflow/plugins/uv/__init__.py +0 -0
  235. metaflow/plugins/uv/bootstrap.py +128 -0
  236. metaflow/plugins/uv/uv_environment.py +72 -0
  237. metaflow/procpoll.py +1 -1
  238. metaflow/pylint_wrapper.py +5 -1
  239. metaflow/runner/__init__.py +0 -0
  240. metaflow/runner/click_api.py +717 -0
  241. metaflow/runner/deployer.py +470 -0
  242. metaflow/runner/deployer_impl.py +201 -0
  243. metaflow/runner/metaflow_runner.py +714 -0
  244. metaflow/runner/nbdeploy.py +132 -0
  245. metaflow/runner/nbrun.py +225 -0
  246. metaflow/runner/subprocess_manager.py +650 -0
  247. metaflow/runner/utils.py +335 -0
  248. metaflow/runtime.py +1078 -260
  249. metaflow/sidecar/sidecar_worker.py +1 -1
  250. metaflow/system/__init__.py +5 -0
  251. metaflow/system/system_logger.py +85 -0
  252. metaflow/system/system_monitor.py +108 -0
  253. metaflow/system/system_utils.py +19 -0
  254. metaflow/task.py +521 -225
  255. metaflow/tracing/__init__.py +7 -7
  256. metaflow/tracing/span_exporter.py +31 -38
  257. metaflow/tracing/tracing_modules.py +38 -43
  258. metaflow/tuple_util.py +27 -0
  259. metaflow/user_configs/__init__.py +0 -0
  260. metaflow/user_configs/config_options.py +563 -0
  261. metaflow/user_configs/config_parameters.py +598 -0
  262. metaflow/user_decorators/__init__.py +0 -0
  263. metaflow/user_decorators/common.py +144 -0
  264. metaflow/user_decorators/mutable_flow.py +512 -0
  265. metaflow/user_decorators/mutable_step.py +424 -0
  266. metaflow/user_decorators/user_flow_decorator.py +264 -0
  267. metaflow/user_decorators/user_step_decorator.py +749 -0
  268. metaflow/util.py +243 -27
  269. metaflow/vendor.py +23 -7
  270. metaflow/version.py +1 -1
  271. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Makefile +355 -0
  272. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Tiltfile +726 -0
  273. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/pick_services.sh +105 -0
  274. ob_metaflow-2.19.7.1rc0.dist-info/METADATA +87 -0
  275. ob_metaflow-2.19.7.1rc0.dist-info/RECORD +445 -0
  276. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/WHEEL +1 -1
  277. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/entry_points.txt +1 -0
  278. metaflow/_vendor/v3_5/__init__.py +0 -1
  279. metaflow/_vendor/v3_5/importlib_metadata/__init__.py +0 -644
  280. metaflow/_vendor/v3_5/importlib_metadata/_compat.py +0 -152
  281. metaflow/package.py +0 -188
  282. ob_metaflow-2.11.13.1.dist-info/METADATA +0 -85
  283. ob_metaflow-2.11.13.1.dist-info/RECORD +0 -308
  284. /metaflow/_vendor/{v3_5/zipp.py → zipp.py} +0 -0
  285. /metaflow/{metadata → metadata_provider}/__init__.py +0 -0
  286. /metaflow/{metadata → metadata_provider}/util.py +0 -0
  287. /metaflow/plugins/{metadata → metadata_providers}/__init__.py +0 -0
  288. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info/licenses}/LICENSE +0 -0
  289. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/top_level.txt +0 -0
metaflow/plugins/kubernetes/kubernetes_jobsets.py
@@ -0,0 +1,1071 @@
+ import json
+ import math
+ import random
+ import time
+ from collections import namedtuple
+ from metaflow.exception import MetaflowException
+ from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION
+ from metaflow.tracing import inject_tracing_vars
+ from metaflow._vendor import yaml
+
+ from .kube_utils import qos_requests_and_limits
+
+
+ class KubernetesJobsetException(MetaflowException):
+     headline = "Kubernetes jobset error"
+
+
+ # TODO [DUPLICATE CODE]: Refactor this method to a separate file so that
+ # it can be used by both KubernetesJob and KubernetesJobSet.
+ def k8s_retry(deadline_seconds=60, max_backoff=32):
+     def decorator(function):
+         from functools import wraps
+
+         @wraps(function)
+         def wrapper(*args, **kwargs):
+             from kubernetes import client
+
+             deadline = time.time() + deadline_seconds
+             retry_number = 0
+
+             while True:
+                 try:
+                     result = function(*args, **kwargs)
+                     return result
+                 except client.rest.ApiException as e:
+                     if e.status == 500:
+                         current_t = time.time()
+                         backoff_delay = min(
+                             math.pow(2, retry_number) + random.random(), max_backoff
+                         )
+                         if current_t + backoff_delay < deadline:
+                             time.sleep(backoff_delay)
+                             retry_number += 1
+                             continue  # retry again
+                         else:
+                             raise
+                     else:
+                         raise
+
+         return wrapper
+
+     return decorator
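For orientation, a minimal usage sketch of the decorator above (not part of the package; the function name, arguments, and cluster config are hypothetical). Only HTTP 500 ApiExceptions are retried, with exponential backoff plus jitter capped at max_backoff, until deadline_seconds elapses; anything else re-raises on the first attempt.

    from kubernetes import client, config

    @k8s_retry(deadline_seconds=60, max_backoff=32)
    def list_jobset_pods(namespace, jobset_name):
        # Transient 500s from the API server are retried transparently;
        # 4xx errors (e.g. 404) surface on the first attempt.
        config.load_kube_config()
        v1 = client.CoreV1Api()
        return v1.list_namespaced_pod(
            namespace=namespace,
            label_selector="jobset.sigs.k8s.io/jobset-name=%s" % jobset_name,
        )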
+
+
+ JobsetStatus = namedtuple(
+     "JobsetStatus",
+     [
+         "control_pod_failed",  # boolean
+         "control_exit_code",
+         "control_pod_status",  # string like (<pod-status>):(<container-status>) [used for user-messaging]
+         "control_started",
+         "control_completed",
+         "worker_pods_failed",
+         "workers_are_suspended",
+         "workers_have_started",
+         "all_jobs_are_suspended",
+         "jobset_finished",
+         "jobset_failed",
+         "status_unknown",
+         "jobset_was_terminated",
+         "some_jobs_are_running",
+     ],
+ )
+
+
+ def _basic_validation_for_js(jobset):
+     if not jobset.get("status") or not _retrieve_replicated_job_statuses(jobset):
+         return False
+     worker_jobs = [
+         w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "worker"
+     ]
+     if len(worker_jobs) == 0:
+         raise KubernetesJobsetException("No worker jobs found in the jobset manifest")
+     control_job = [
+         w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "control"
+     ]
+     if len(control_job) == 0:
+         raise KubernetesJobsetException("No control job found in the jobset manifest")
+     return True
+
+
+ def _derive_pod_status_and_status_code(control_pod):
+     overall_status = None
+     control_exit_code = None
+     control_pod_failed = False
+     if control_pod:
+         container_status = None
+         pod_status = control_pod.get("status", {}).get("phase")
+         container_statuses = control_pod.get("status", {}).get("containerStatuses")
+         if container_statuses is None:
+             container_status = ": ".join(
+                 filter(
+                     None,
+                     [
+                         control_pod.get("status", {}).get("reason"),
+                         control_pod.get("status", {}).get("message"),
+                     ],
+                 )
+             )
+         else:
+             for k, v in container_statuses[0].get("state", {}).items():
+                 if v is not None:
+                     control_exit_code = v.get("exit_code")
+                     container_status = ": ".join(
+                         filter(
+                             None,
+                             [v.get("reason"), v.get("message")],
+                         )
+                     )
+         if container_status is None:
+             overall_status = "pod status: %s" % pod_status
+         else:
+             overall_status = "pod status: %s | container status: %s" % (
+                 pod_status,
+                 container_status,
+             )
+         if pod_status == "Failed":
+             control_pod_failed = True
+     return overall_status, control_exit_code, control_pod_failed
+
+
+ def _retrieve_replicated_job_statuses(jobset):
+     # We need this abstraction because JobSet changed its schema in
+     # version v0.3.0, where `ReplicatedJobsStatus` became `replicatedJobsStatus`.
+     # To handle users running an older version of JobSet, we account
+     # for both schemas.
+     if jobset.get("status", {}).get("replicatedJobsStatus", None):
+         return jobset.get("status").get("replicatedJobsStatus")
+     elif jobset.get("status", {}).get("ReplicatedJobsStatus", None):
+         return jobset.get("status").get("ReplicatedJobsStatus")
+     return None
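As a hedged illustration of the two status shapes the helper above accepts (field values hypothetical, reduced to the keys the code reads):

    new_style = {"status": {"replicatedJobsStatus": [{"name": "control", "active": 1}]}}  # JobSet >= v0.3.0
    old_style = {"status": {"ReplicatedJobsStatus": [{"name": "control", "active": 1}]}}  # JobSet < v0.3.0
    assert _retrieve_replicated_job_statuses(new_style) == _retrieve_replicated_job_statuses(old_style)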
+
+
+ def _construct_jobset_logical_status(jobset, control_pod=None):
+     if not _basic_validation_for_js(jobset):
+         return JobsetStatus(
+             control_started=False,
+             control_completed=False,
+             workers_are_suspended=False,
+             workers_have_started=False,
+             all_jobs_are_suspended=False,
+             jobset_finished=False,
+             jobset_failed=False,
+             status_unknown=True,
+             jobset_was_terminated=False,
+             control_exit_code=None,
+             control_pod_status=None,
+             worker_pods_failed=False,
+             control_pod_failed=False,
+             some_jobs_are_running=False,
+         )
+
+     js_status = jobset.get("status")
+
+     control_started = False
+     control_completed = False
+     workers_are_suspended = False
+     workers_have_started = False
+     all_jobs_are_suspended = jobset.get("spec", {}).get("suspend", False)
+     jobset_finished = False
+     jobset_failed = False
+     status_unknown = False
+     jobset_was_terminated = False
+     worker_pods_failed = False
+     some_jobs_are_running = False
+
+     total_worker_jobs = [
+         w["replicas"]
+         for w in jobset.get("spec").get("replicatedJobs", [])
+         if w["name"] == "worker"
+     ][0]
+     total_control_jobs = [
+         w["replicas"]
+         for w in jobset.get("spec").get("replicatedJobs", [])
+         if w["name"] == "control"
+     ][0]
+
+     if total_worker_jobs == 0 and total_control_jobs == 0:
+         jobset_was_terminated = True
+
+     replicated_job_statuses = _retrieve_replicated_job_statuses(jobset)
+     for job_status in replicated_job_statuses:
+         if job_status["active"] > 0:
+             some_jobs_are_running = True
+
+         if job_status["name"] == "control":
+             control_started = job_status["active"] > 0 or job_status["succeeded"] > 0
+             control_completed = job_status["succeeded"] > 0
+             if job_status["failed"] > 0:
+                 jobset_failed = True
+
+         if job_status["name"] == "worker":
+             workers_have_started = job_status["active"] == total_worker_jobs
+             if "suspended" in job_status:
+                 # `replicatedJobsStatus` didn't have a `suspended` field
+                 # until v0.3.0, so we need to account for that.
+                 workers_are_suspended = job_status["suspended"] > 0
+             if job_status["failed"] > 0:
+                 worker_pods_failed = True
+                 jobset_failed = True
+
+     if js_status.get("conditions"):
+         for condition in js_status["conditions"]:
+             if condition["type"] == "Completed":
+                 jobset_finished = True
+             if condition["type"] == "Failed":
+                 jobset_failed = True
+
+     (
+         overall_status,
+         control_exit_code,
+         control_pod_failed,
+     ) = _derive_pod_status_and_status_code(control_pod)
+
+     return JobsetStatus(
+         control_started=control_started,
+         control_completed=control_completed,
+         workers_are_suspended=workers_are_suspended,
+         workers_have_started=workers_have_started,
+         all_jobs_are_suspended=all_jobs_are_suspended,
+         jobset_finished=jobset_finished,
+         jobset_failed=jobset_failed,
+         status_unknown=status_unknown,
+         jobset_was_terminated=jobset_was_terminated,
+         control_exit_code=control_exit_code,
+         control_pod_status=overall_status,
+         worker_pods_failed=worker_pods_failed,
+         control_pod_failed=control_pod_failed,
+         some_jobs_are_running=some_jobs_are_running,
+     )
+
+
+ class RunningJobSet(object):
+     def __init__(self, client, name, namespace, group, version):
+         self._client = client
+         self._name = name
+         self._pod_name = None
+         self._namespace = namespace
+         self._group = group
+         self._version = version
+         self._pod = self._fetch_pod()
+         self._jobset = self._fetch_jobset()
+
+         import atexit
+
+         def best_effort_kill():
+             try:
+                 self.kill()
+             except Exception:
+                 pass
+
+         atexit.register(best_effort_kill)
+
+     def __repr__(self):
+         return "{}('{}/{}')".format(
+             self.__class__.__name__, self._namespace, self._name
+         )
+
+     @k8s_retry()
+     def _fetch_jobset(self):
+         # name : name of the jobset
+         # namespace : namespace of the jobset
+         # Query the jobset and return the object's status field as a JSON object
+         client = self._client.get()
+         with client.ApiClient() as api_client:
+             api_instance = client.CustomObjectsApi(api_client)
+             try:
+                 jobset = api_instance.get_namespaced_custom_object(
+                     group=self._group,
+                     version=self._version,
+                     namespace=self._namespace,
+                     plural="jobsets",
+                     name=self._name,
+                 )
+                 return jobset
+             except client.rest.ApiException as e:
+                 if e.status == 404:
+                     raise KubernetesJobsetException(
+                         "Unable to locate Kubernetes jobset %s" % self._name
+                     )
+                 raise
+
+     @k8s_retry()
+     def _fetch_pod(self):
+         # Fetch pod metadata.
+         client = self._client.get()
+         pods = (
+             client.CoreV1Api()
+             .list_namespaced_pod(
+                 namespace=self._namespace,
+                 label_selector="jobset.sigs.k8s.io/jobset-name={}".format(self._name),
+             )
+             .to_dict()["items"]
+         )
+         if pods:
+             for pod in pods:
+                 # Check the labels of the pod to see if
+                 # `jobset.sigs.k8s.io/replicatedjob-name` is set to `control`.
+                 if (
+                     pod["metadata"]["labels"].get(
+                         "jobset.sigs.k8s.io/replicatedjob-name"
+                     )
+                     == "control"
+                 ):
+                     return pod
+         return {}
+
+     def kill(self):
+         plural = "jobsets"
+         client = self._client.get()
+         if not (self.is_running or self.is_waiting):
+             return
+         try:
+             # Killing the control pod will trigger the jobset to mark everything as
+             # failed, since jobsets have a successPolicy set to `All`, which requires
+             # every job to succeed for the jobset to succeed.
+             from kubernetes.stream import stream
+
+             control_pod = self._fetch_pod()
+             stream(
+                 client.CoreV1Api().connect_get_namespaced_pod_exec,
+                 name=control_pod["metadata"]["name"],
+                 namespace=control_pod["metadata"]["namespace"],
+                 command=[
+                     "/bin/sh",
+                     "-c",
+                     "/sbin/killall5",
+                 ],
+                 stderr=True,
+                 stdin=False,
+                 stdout=True,
+                 tty=False,
+             )
+         except Exception:
+             with client.ApiClient() as api_client:
+                 # If we are unable to kill the control pod then
+                 # delete the jobset to kill the subsequent pods.
+                 # There are a few reasons for deleting a jobset to kill it:
+                 # 1. Jobset has a `suspend` attribute to suspend its execution, but this
+                 #    doesn't play nicely when jobsets are deployed with other components like kueue.
+                 # 2. Jobset doesn't play nicely when we mutate status.
+                 # 3. Deletion is a guaranteed way of removing any pods.
+                 api_instance = client.CustomObjectsApi(api_client)
+                 try:
+                     api_instance.delete_namespaced_custom_object(
+                         group=self._group,
+                         version=self._version,
+                         namespace=self._namespace,
+                         plural=plural,
+                         name=self._name,
+                     )
+                 except Exception as e:
+                     raise KubernetesJobsetException(
+                         "Exception when deleting existing jobset: %s\n" % e
+                     )
+
+     @property
+     def id(self):
+         if self._pod_name:
+             return "pod %s" % self._pod_name
+         if self._pod:
+             self._pod_name = self._pod["metadata"]["name"]
+             return self.id
+         return "jobset %s" % self._name
+
+     @property
+     def is_done(self):
+         def done():
+             return (
+                 self._jobset_is_completed
+                 or self._jobset_has_failed
+                 or self._jobset_was_terminated
+             )
+
+         if not done():
+             # If not done, fetch newer status
+             self._jobset = self._fetch_jobset()
+             self._pod = self._fetch_pod()
+         return done()
+
+     @property
+     def status(self):
+         if self.is_done:
+             return "Jobset is done"
+
+         status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
+         if status.status_unknown:
+             return "Jobset status is unknown"
+         if status.control_started:
+             if status.control_pod_status:
+                 return "Jobset is running: %s" % status.control_pod_status
+             return "Jobset is running"
+         if status.all_jobs_are_suspended:
+             return "Jobset is waiting to be unsuspended"
+
+         return "Jobset waiting for jobs to start"
+
+     @property
+     def has_succeeded(self):
+         return self.is_done and self._jobset_is_completed
+
+     @property
+     def has_failed(self):
+         return self.is_done and self._jobset_has_failed
+
+     @property
+     def is_running(self):
+         if self.is_done:
+             return False
+         status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
+         if status.some_jobs_are_running:
+             return True
+         return False
+
+     @property
+     def _jobset_was_terminated(self):
+         return _construct_jobset_logical_status(
+             self._jobset, control_pod=self._pod
+         ).jobset_was_terminated
+
+     @property
+     def is_waiting(self):
+         return not self.is_done and not self.is_running
+
+     @property
+     def reason(self):
+         # Return exit code and reason
+         if self.is_done and not self.has_succeeded:
+             self._pod = self._fetch_pod()
+         elif self.has_succeeded:
+             return 0, None
+         status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
+         if status.control_pod_failed:
+             return (
+                 status.control_exit_code,
+                 "control-pod failed [%s]" % status.control_pod_status,
+             )
+         elif status.worker_pods_failed:
+             return None, "Worker pods failed"
+         return None, None
+
+     @property
+     def _jobset_is_completed(self):
+         return _construct_jobset_logical_status(
+             self._jobset, control_pod=self._pod
+         ).jobset_finished
+
+     @property
+     def _jobset_has_failed(self):
+         return _construct_jobset_logical_status(
+             self._jobset, control_pod=self._pod
+         ).jobset_failed
+
+
+ def _make_domain_name(
+     jobset_name, main_job_name, main_job_index, main_pod_index, namespace
+ ):
+     return "%s-%s-%s-%s.%s.%s.svc.cluster.local" % (
+         jobset_name,
+         main_job_name,
+         main_job_index,
+         main_pod_index,
+         jobset_name,
+         namespace,
+     )
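This appears to follow JobSet's pod naming convention (<jobset>-<replicatedjob>-<job-index>-<pod-index>), with the jobset name doubling as the headless-service subdomain. A quick sketch with hypothetical names:

    addr = _make_domain_name("js-abc", "control", 0, 0, "default")
    # addr == "js-abc-control-0-0.js-abc.default.svc.cluster.local"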
+
+
+ class JobSetSpec(object):
+     def __init__(self, kubernetes_sdk, name, **kwargs):
+         self._kubernetes_sdk = kubernetes_sdk
+         self._kwargs = kwargs
+         self.name = name
+
+     def replicas(self, replicas):
+         self._kwargs["replicas"] = replicas
+         return self
+
+     def step_name(self, step_name):
+         self._kwargs["step_name"] = step_name
+         return self
+
+     def namespace(self, namespace):
+         self._kwargs["namespace"] = namespace
+         return self
+
+     def command(self, command):
+         self._kwargs["command"] = command
+         return self
+
+     def image(self, image):
+         self._kwargs["image"] = image
+         return self
+
+     def cpu(self, cpu):
+         self._kwargs["cpu"] = cpu
+         return self
+
+     def memory(self, mem):
+         self._kwargs["memory"] = mem
+         return self
+
+     def environment_variable(self, name, value):
+         # Never set to None
+         if value is None:
+             return self
+         self._kwargs["environment_variables"] = dict(
+             self._kwargs.get("environment_variables", {}), **{name: value}
+         )
+         return self
+
+     def secret(self, name):
+         if name is None:
+             return self
+         if len(self._kwargs.get("secrets", [])) == 0:
+             self._kwargs["secrets"] = []
+         self._kwargs["secrets"] = list(set(self._kwargs["secrets"] + [name]))
+         return self
+
+     def environment_variable_from_selector(self, name, label_value):
+         # Never set to None
+         if label_value is None:
+             return self
+         self._kwargs["environment_variables_from_selectors"] = dict(
+             self._kwargs.get("environment_variables_from_selectors", {}),
+             **{name: label_value}
+         )
+         return self
+
+     def label(self, name, value):
+         self._kwargs["labels"] = dict(self._kwargs.get("labels", {}), **{name: value})
+         return self
+
+     def annotation(self, name, value):
+         self._kwargs["annotations"] = dict(
+             self._kwargs.get("annotations", {}), **{name: value}
+         )
+         return self
+
+     def dump(self):
+         client = self._kubernetes_sdk
+         use_tmpfs = self._kwargs["use_tmpfs"]
+         tmpfs_size = self._kwargs["tmpfs_size"]
+         tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
+         shared_memory = (
+             int(self._kwargs["shared_memory"])
+             if self._kwargs["shared_memory"]
+             else None
+         )
+         qos_requests, qos_limits = qos_requests_and_limits(
+             self._kwargs["qos"],
+             self._kwargs["cpu"],
+             self._kwargs["memory"],
+             self._kwargs["disk"],
+         )
+         security_context = self._kwargs.get("security_context", {})
+         _security_context = {}
+         if security_context is not None and len(security_context) > 0:
+             _security_context = {
+                 "security_context": client.V1SecurityContext(**security_context)
+             }
+         return dict(
+             name=self.name,
+             template=client.api_client.ApiClient().sanitize_for_serialization(
+                 client.V1JobTemplateSpec(
+                     metadata=client.V1ObjectMeta(
+                         namespace=self._kwargs["namespace"],
+                         # We don't set any annotations here
+                         # since they have either been set in the JobSpec
+                         # or on the JobSet level
+                     ),
+                     spec=client.V1JobSpec(
+                         # Retries are handled by Metaflow when it is responsible for
+                         # executing the flow. The responsibility is moved to Kubernetes
+                         # when Argo Workflows is responsible for the execution.
+                         backoff_limit=self._kwargs.get("retries", 0),
+                         completions=1,
+                         parallelism=1,
+                         # Remove the job after a week. TODO: Make this configurable.
+                         ttl_seconds_after_finished=7 * 60 * 60 * 24,
+                         template=client.V1PodTemplateSpec(
+                             metadata=client.V1ObjectMeta(
+                                 annotations=self._kwargs.get("annotations", {}),
+                                 labels=self._kwargs.get("labels", {}),
+                                 namespace=self._kwargs["namespace"],
+                             ),
+                             spec=client.V1PodSpec(
+                                 subdomain=self._kwargs["subdomain"],
+                                 set_hostname_as_fqdn=True,
+                                 # Timeout is set on the pod and not the job (important!)
+                                 active_deadline_seconds=self._kwargs[
+                                     "timeout_in_seconds"
+                                 ],
+                                 # TODO (savin): Enable affinities for GPU scheduling.
+                                 # affinity=?,
+                                 containers=[
+                                     client.V1Container(
+                                         command=self._kwargs["command"],
+                                         termination_message_policy="FallbackToLogsOnError",
+                                         ports=(
+                                             []
+                                             if self._kwargs["port"] is None
+                                             else [
+                                                 client.V1ContainerPort(
+                                                     container_port=int(
+                                                         self._kwargs["port"]
+                                                     )
+                                                 )
+                                             ]
+                                         ),
+                                         env=[
+                                             client.V1EnvVar(name=k, value=str(v))
+                                             for k, v in self._kwargs.get(
+                                                 "environment_variables", {}
+                                             ).items()
+                                         ]
+                                         # And some downward API magic. Add (key, value)
+                                         # pairs below to make pod metadata available
+                                         # within the Kubernetes container.
+                                         + [
+                                             client.V1EnvVar(
+                                                 name=k,
+                                                 value_from=client.V1EnvVarSource(
+                                                     field_ref=client.V1ObjectFieldSelector(
+                                                         field_path=str(v)
+                                                     )
+                                                 ),
+                                             )
+                                             for k, v in self._kwargs.get(
+                                                 "environment_variables_from_selectors",
+                                                 {},
+                                             ).items()
+                                         ]
+                                         + [
+                                             client.V1EnvVar(name=k, value=str(v))
+                                             for k, v in inject_tracing_vars({}).items()
+                                         ],
+                                         env_from=[
+                                             client.V1EnvFromSource(
+                                                 secret_ref=client.V1SecretEnvSource(
+                                                     name=str(k),
+                                                     # optional=True
+                                                 )
+                                             )
+                                             for k in list(
+                                                 self._kwargs.get("secrets", [])
+                                             )
+                                             if k
+                                         ],
+                                         image=self._kwargs["image"],
+                                         image_pull_policy=self._kwargs[
+                                             "image_pull_policy"
+                                         ],
+                                         name=self._kwargs["step_name"].replace(
+                                             "_", "-"
+                                         ),
+                                         resources=client.V1ResourceRequirements(
+                                             requests=qos_requests,
+                                             limits={
+                                                 **qos_limits,
+                                                 **{
+                                                     (
+                                                         "%s.com/gpu"
+                                                         % self._kwargs["gpu_vendor"]
+                                                     ).lower(): str(
+                                                         self._kwargs["gpu"]
+                                                     )
+                                                     for k in [0]
+                                                     # Don't set GPU limits if gpu isn't specified.
+                                                     if self._kwargs["gpu"] is not None
+                                                 },
+                                             },
+                                         ),
+                                         volume_mounts=(
+                                             [
+                                                 client.V1VolumeMount(
+                                                     mount_path=self._kwargs.get(
+                                                         "tmpfs_path"
+                                                     ),
+                                                     name="tmpfs-ephemeral-volume",
+                                                 )
+                                             ]
+                                             if tmpfs_enabled
+                                             else []
+                                         )
+                                         + (
+                                             [
+                                                 client.V1VolumeMount(
+                                                     mount_path="/dev/shm", name="dhsm"
+                                                 )
+                                             ]
+                                             if shared_memory
+                                             else []
+                                         )
+                                         + (
+                                             [
+                                                 client.V1VolumeMount(
+                                                     mount_path=path, name=claim
+                                                 )
+                                                 for claim, path in self._kwargs[
+                                                     "persistent_volume_claims"
+                                                 ].items()
+                                             ]
+                                             if self._kwargs["persistent_volume_claims"]
+                                             is not None
+                                             else []
+                                         ),
+                                         **_security_context,
+                                     )
+                                 ],
+                                 node_selector=self._kwargs.get("node_selector"),
+                                 image_pull_secrets=[
+                                     client.V1LocalObjectReference(secret)
+                                     for secret in self._kwargs.get("image_pull_secrets")
+                                     or []
+                                 ],
+                                 # TODO (savin): Support preemption policies
+                                 # preemption_policy=?,
+                                 #
+                                 # A container in a pod may fail for a number of
+                                 # reasons, such as because the process in it exited
+                                 # with a non-zero exit code, or the container was
+                                 # killed due to OOM etc. If this happens, fail the pod
+                                 # and let Metaflow handle the retries.
+                                 restart_policy="Never",
+                                 service_account_name=self._kwargs["service_account"],
+                                 # Terminate the container immediately on SIGTERM
+                                 termination_grace_period_seconds=0,
+                                 tolerations=[
+                                     client.V1Toleration(**toleration)
+                                     for toleration in self._kwargs.get("tolerations")
+                                     or []
+                                 ],
+                                 volumes=(
+                                     [
+                                         client.V1Volume(
+                                             name="tmpfs-ephemeral-volume",
+                                             empty_dir=client.V1EmptyDirVolumeSource(
+                                                 medium="Memory",
+                                                 # Add default unit as ours differs from Kubernetes default.
+                                                 size_limit="{}Mi".format(tmpfs_size),
+                                             ),
+                                         )
+                                     ]
+                                     if tmpfs_enabled
+                                     else []
+                                 )
+                                 + (
+                                     [
+                                         client.V1Volume(
+                                             name="dhsm",
+                                             empty_dir=client.V1EmptyDirVolumeSource(
+                                                 medium="Memory",
+                                                 size_limit="{}Mi".format(shared_memory),
+                                             ),
+                                         )
+                                     ]
+                                     if shared_memory
+                                     else []
+                                 )
+                                 + (
+                                     [
+                                         client.V1Volume(
+                                             name=claim,
+                                             persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
+                                                 claim_name=claim
+                                             ),
+                                         )
+                                         for claim in self._kwargs[
+                                             "persistent_volume_claims"
+                                         ].keys()
+                                     ]
+                                     if self._kwargs["persistent_volume_claims"]
+                                     is not None
+                                     else []
+                                 ),
+                             ),
+                         ),
+                     ),
+                 )
+             ),
+             replicas=self._kwargs["replicas"],
+         )
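A hypothetical sketch of driving the fluent builder above (all values invented; note that dump() additionally expects keys such as use_tmpfs, tmpfs_size, shared_memory, qos, disk, subdomain, timeout_in_seconds, port, gpu, gpu_vendor, image_pull_policy, and service_account to be present in the spec's kwargs):

    from kubernetes import client as kubernetes_sdk

    spec = JobSetSpec(kubernetes_sdk, name="worker", namespace="default")
    (
        spec.replicas(3)
        .step_name("train")
        .image("python:3.11")
        .command(["python", "flow.py", "step", "train"])
        .cpu("2")
        .memory("4096")
        .environment_variable("METAFLOW_STEP_NAME", "train")
        .label("app", "metaflow")
        .annotation("metaflow/attempt", "0")
    )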
+
+
+ class KubernetesJobSet(object):
+     def __init__(
+         self,
+         client,
+         name=None,
+         namespace=None,
+         num_parallel=None,
+         # Explicitly declaring num_parallel because we need to ensure that
+         # num_parallel is an INTEGER and this abstraction is called by the
+         # local runtime abstraction of Kubernetes.
+         # Argo will call another abstraction that allows setting a lot of these
+         # values from the top-level Argo code.
+         **kwargs
+     ):
+         self._client = client
+         self._annotations = {}
+         self._labels = {}
+         self._group = KUBERNETES_JOBSET_GROUP
+         self._version = KUBERNETES_JOBSET_VERSION
+         self._namespace = namespace
+         self.name = name
+
+         self._jobset_control_addr = _make_domain_name(
+             name,
+             "control",
+             0,
+             0,
+             namespace,
+         )
+
+         self._control_spec = JobSetSpec(
+             client.get(), name="control", namespace=namespace, **kwargs
+         )
+         self._worker_spec = JobSetSpec(
+             client.get(), name="worker", namespace=namespace, **kwargs
+         )
+         assert (
+             type(num_parallel) == int
+         ), "num_parallel must be an integer"  # todo: [final-refactor] : fix-me
+
+     @property
+     def jobset_control_addr(self):
+         return self._jobset_control_addr
+
+     @property
+     def worker(self):
+         return self._worker_spec
+
+     @property
+     def control(self):
+         return self._control_spec
+
+     def environment_variable_from_selector(self, name, label_value):
+         self.worker.environment_variable_from_selector(name, label_value)
+         self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variables_from_selectors(self, env_dict):
+         for name, label_value in env_dict.items():
+             self.worker.environment_variable_from_selector(name, label_value)
+             self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variable(self, name, value):
+         self.worker.environment_variable(name, value)
+         self.control.environment_variable(name, value)
+         return self
+
+     def label(self, name, value):
+         self.worker.label(name, value)
+         self.control.label(name, value)
+         self._labels = dict(self._labels, **{name: value})
+         return self
+
+     def annotation(self, name, value):
+         self.worker.annotation(name, value)
+         self.control.annotation(name, value)
+         self._annotations = dict(self._annotations, **{name: value})
+         return self
+
+     def labels(self, labels):
+         for k, v in labels.items():
+             self.label(k, v)
+         return self
+
+     def annotations(self, annotations):
+         for k, v in annotations.items():
+             self.annotation(k, v)
+         return self
+
+     def secret(self, name):
+         self.worker.secret(name)
+         self.control.secret(name)
+         return self
+
+     def dump(self):
+         client = self._client.get()
+         return dict(
+             apiVersion=self._group + "/" + self._version,
+             kind="JobSet",
+             metadata=client.api_client.ApiClient().sanitize_for_serialization(
+                 client.V1ObjectMeta(
+                     name=self.name,
+                     labels=self._labels,
+                     annotations=self._annotations,
+                 )
+             ),
+             spec=dict(
+                 replicatedJobs=[self.control.dump(), self.worker.dump()],
+                 suspend=False,
+                 startupPolicy=dict(
+                     # We explicitly set an InOrder startup policy so that
+                     # the control pod starts before the worker pods.
+                     # This is required so that when worker pods try to access the
+                     # control's IP, we are able to resolve the control's IP address.
+                     startupPolicyOrder="InOrder"
+                 ),
+                 successPolicy=None,
+                 # The failure policy can set the number of retries for the jobset,
+                 # but we don't rely on it; instead we rely on either the local
+                 # scheduler or Argo Workflows to handle retries.
+                 failurePolicy=None,
+                 network=None,
+             ),
+             status=None,
+         )
+
+     def execute(self):
+         client = self._client.get()
+
+         with client.ApiClient() as api_client:
+             api_instance = client.CustomObjectsApi(api_client)
+             try:
+                 jobset_obj = api_instance.create_namespaced_custom_object(
+                     group=self._group,
+                     version=self._version,
+                     namespace=self._namespace,
+                     plural="jobsets",
+                     body=self.dump(),
+                 )
+             except Exception as e:
+                 raise KubernetesJobsetException(
+                     "Exception when calling CustomObjectsApi->create_namespaced_custom_object: %s\n"
+                     % e
+                 )
+
+         return RunningJobSet(
+             client=self._client,
+             name=jobset_obj["metadata"]["name"],
+             namespace=jobset_obj["metadata"]["namespace"],
+             group=self._group,
+             version=self._version,
+         )
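Putting the pieces together, a hedged end-to-end sketch (here client stands in for Metaflow's Kubernetes client wrapper, whose .get() returns the kubernetes SDK module; the names, env var, and remaining kwargs are all hypothetical):

    jobset = KubernetesJobSet(
        client,                   # wrapper exposing .get()
        name="js-train-0",
        namespace="default",
        num_parallel=4,           # must be an int (asserted above)
        # ...plus the per-container kwargs that JobSetSpec.dump() reads
    )
    jobset.control.replicas(1)
    jobset.worker.replicas(3)     # num_parallel - 1 workers alongside one control job
    jobset.environment_variable("MF_JOBSET_CONTROL_ADDR", jobset.jobset_control_addr)
    running = jobset.execute()    # submits the JobSet CRD, returns a RunningJobSet
    print(running.status)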
+
+
+ class KubernetesArgoJobSet(object):
+     def __init__(self, kubernetes_sdk, name=None, namespace=None, **kwargs):
+         self._kubernetes_sdk = kubernetes_sdk
+         self._annotations = {}
+         self._labels = {}
+         self._group = KUBERNETES_JOBSET_GROUP
+         self._version = KUBERNETES_JOBSET_VERSION
+         self._namespace = namespace
+         self.name = name
+
+         self._jobset_control_addr = _make_domain_name(
+             name,
+             "control",
+             0,
+             0,
+             namespace,
+         )
+
+         self._control_spec = JobSetSpec(
+             kubernetes_sdk, name="control", namespace=namespace, **kwargs
+         )
+         self._worker_spec = JobSetSpec(
+             kubernetes_sdk, name="worker", namespace=namespace, **kwargs
+         )
+
+     @property
+     def jobset_control_addr(self):
+         return self._jobset_control_addr
+
+     @property
+     def worker(self):
+         return self._worker_spec
+
+     @property
+     def control(self):
+         return self._control_spec
+
+     def environment_variable_from_selector(self, name, label_value):
+         self.worker.environment_variable_from_selector(name, label_value)
+         self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variables_from_selectors(self, env_dict):
+         for name, label_value in env_dict.items():
+             self.worker.environment_variable_from_selector(name, label_value)
+             self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variable(self, name, value):
+         self.worker.environment_variable(name, value)
+         self.control.environment_variable(name, value)
+         return self
+
+     def label(self, name, value):
+         self.worker.label(name, value)
+         self.control.label(name, value)
+         self._labels = dict(self._labels, **{name: value})
+         return self
+
+     def labels(self, labels):
+         for k, v in labels.items():
+             self.label(k, v)
+         return self
+
+     def annotation(self, name, value):
+         self.worker.annotation(name, value)
+         self.control.annotation(name, value)
+         self._annotations = dict(self._annotations, **{name: value})
+         return self
+
+     def annotations(self, annotations):
+         for k, v in annotations.items():
+             self.annotation(k, v)
+         return self
+
+     def dump(self):
+         client = self._kubernetes_sdk
+         js_dict = client.ApiClient().sanitize_for_serialization(
+             dict(
+                 apiVersion=self._group + "/" + self._version,
+                 kind="JobSet",
+                 metadata=client.api_client.ApiClient().sanitize_for_serialization(
+                     client.V1ObjectMeta(
+                         name=self.name,
+                         labels=self._labels,
+                         annotations=self._annotations,
+                     )
+                 ),
+                 spec=dict(
+                     replicatedJobs=[self.control.dump(), self.worker.dump()],
+                     suspend=False,
+                     startupPolicy=None,
+                     successPolicy=None,
+                     # The failure policy can set the number of retries for the jobset,
+                     # but we don't rely on it; instead we rely on either the local
+                     # scheduler or Argo Workflows to handle retries.
+                     failurePolicy=None,
+                     network=None,
+                 ),
+                 status=None,
+             )
+         )
+         data = yaml.dump(js_dict, default_flow_style=False, indent=2)
+         # The values we populate in the JobSet manifest (for Argo Workflows) piggyback
+         # on Argo Workflows' templating engine. Even though that templating helps us
+         # construct all the necessary IDs and populate the fields required by Metaflow,
+         # we run into one glitch: when we construct JSON/YAML-serializable objects,
+         # anything between two braces, such as `{{=asInt(inputs.parameters.workerCount)}}`,
+         # gets quoted. This is a problem, since we need to pass the value of
+         # `inputs.parameters.workerCount` as an integer and not as a string; if we pass
+         # it as a string, the jobset controller will not accept the JobSet CRD we submit
+         # to Kubernetes. To get around this, we replace the quoted substring with the
+         # unquoted one, since YAML/JSON parsers won't trivially deserialize the value
+         # without the quoting.
+
+         # This is super important because `inputs.parameters.workerCount` is used to set
+         # the number of replicas, and that value is derived from `num_parallel` (which is
+         # set in user code). Since `num_parallel` can change from run to run, we need to
+         # ensure that the value can be passed down dynamically and is **explicitly set as
+         # an integer** in the JobSet manifest submitted as part of the Argo Workflow.
+         quoted_substring = "'{{=asInt(inputs.parameters.workerCount)}}'"
+         unquoted_substring = "{{=asInt(inputs.parameters.workerCount)}}"
+         return data.replace(quoted_substring, unquoted_substring)
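To make the final replace concrete, a small self-contained check of the quoting fix (the replicas field placement is illustrative, not taken from an actual rendered manifest):

    rendered = "replicas: '{{=asInt(inputs.parameters.workerCount)}}'"  # what yaml.dump emits
    fixed = rendered.replace(
        "'{{=asInt(inputs.parameters.workerCount)}}'",
        "{{=asInt(inputs.parameters.workerCount)}}",
    )
    assert fixed == "replicas: {{=asInt(inputs.parameters.workerCount)}}"
    # Argo resolves the template to a bare integer, which the jobset controller accepts.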