skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py CHANGED
@@ -6,18 +6,18 @@ ManagedJobCodeGen.
  """
  import collections
  import enum
- import inspect
  import os
  import pathlib
  import shlex
- import shutil
  import textwrap
  import time
+ import traceback
  import typing
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
  import colorama
  import filelock
+ import psutil
  from typing_extensions import Literal
 
  from sky import backends
@@ -26,14 +26,18 @@ from sky import global_user_state
  from sky import sky_logging
  from sky.backends import backend_utils
  from sky.jobs import constants as managed_job_constants
+ from sky.jobs import scheduler
  from sky.jobs import state as managed_job_state
  from sky.skylet import constants
  from sky.skylet import job_lib
  from sky.skylet import log_lib
+ from sky.usage import usage_lib
  from sky.utils import common_utils
  from sky.utils import log_utils
+ from sky.utils import message_utils
  from sky.utils import rich_utils
  from sky.utils import subprocess_utils
+ from sky.utils import ux_utils
 
  if typing.TYPE_CHECKING:
  import sky
@@ -41,14 +45,7 @@ if typing.TYPE_CHECKING:
 
  logger = sky_logging.init_logger(__name__)
 
- # Add user hash so that two users don't have the same controller VM on
- # shared-account clouds such as GCP.
- JOB_CONTROLLER_NAME: str = (
- f'sky-jobs-controller-{common_utils.get_user_hash()}')
- LEGACY_JOB_CONTROLLER_NAME: str = (
- f'sky-spot-controller-{common_utils.get_user_hash()}')
  SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'
- LEGACY_SIGNAL_FILE_PREFIX = '/tmp/sky_spot_controller_signal_{}'
  # Controller checks its job's status every this many seconds.
  JOB_STATUS_CHECK_GAP_SECONDS = 20
 
@@ -57,17 +54,21 @@ JOB_STARTED_STATUS_CHECK_GAP_SECONDS = 5
 
  _LOG_STREAM_CHECK_CONTROLLER_GAP_SECONDS = 5
 
- _JOB_WAITING_STATUS_MESSAGE = ('[bold cyan]Waiting for the task to start'
- '{status_str}.[/] It may take a few minutes.')
+ _JOB_WAITING_STATUS_MESSAGE = ux_utils.spinner_message(
+ 'Waiting for task to start[/]'
+ '{status_str}. It may take a few minutes.\n'
+ ' [dim]View controller logs: sky jobs logs --controller {job_id}')
  _JOB_CANCELLED_MESSAGE = (
- '[bold cyan]Waiting for the task status to be updated.'
- '[/] It may take a minute.')
+ ux_utils.spinner_message('Waiting for task status to be updated.') +
+ ' It may take a minute.')
 
  # The maximum time to wait for the managed job status to transition to terminal
  # state, after the job finished. This is a safeguard to avoid the case where
  # the managed job status fails to be updated and keep the `sky jobs logs`
- # blocking for a long time.
- _FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 20
+ # blocking for a long time. This should be significantly longer than the
+ # JOB_STATUS_CHECK_GAP_SECONDS to avoid timing out before the controller can
+ # update the state.
+ _FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 40
 
 
  class UserSignal(enum.Enum):
@@ -78,11 +79,50 @@ class UserSignal(enum.Enum):
 
 
  # ====== internal functions ======
+ def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
+ """Terminate the cluster."""
+ from sky import core # pylint: disable=import-outside-toplevel
+ retry_cnt = 0
+ # In some cases, e.g. botocore.exceptions.NoCredentialsError due to AWS
+ # metadata service throttling, the failed sky.down attempt can take 10-11
+ # seconds. In this case, we need the backoff to significantly reduce the
+ # rate of requests - that is, significantly increase the time between
+ # requests. We set the initial backoff to 15 seconds, so that once it grows
+ # exponentially it will quickly dominate the 10-11 seconds that we already
+ # see between requests. We set the max backoff very high, since it's
+ # generally much more important to eventually succeed than to fail fast.
+ backoff = common_utils.Backoff(
+ initial_backoff=15,
+ # 1.6 ** 5 = 10.48576 < 20, so we won't hit this with default max_retry
+ max_backoff_factor=20)
+ while True:
+ try:
+ usage_lib.messages.usage.set_internal()
+ core.down(cluster_name)
+ return
+ except exceptions.ClusterDoesNotExist:
+ # The cluster is already down.
+ logger.debug(f'The cluster {cluster_name} is already down.')
+ return
+ except Exception as e: # pylint: disable=broad-except
+ retry_cnt += 1
+ if retry_cnt >= max_retry:
+ raise RuntimeError(
+ f'Failed to terminate the cluster {cluster_name}.') from e
+ logger.error(
+ f'Failed to terminate the cluster {cluster_name}. Retrying.'
+ f'Details: {common_utils.format_exception(e)}')
+ with ux_utils.enable_traceback():
+ logger.error(f' Traceback: {traceback.format_exc()}')
+ time.sleep(backoff.current_backoff())
+
+
  def get_job_status(backend: 'backends.CloudVmRayBackend',
  cluster_name: str) -> Optional['job_lib.JobStatus']:
  """Check the status of the job running on a managed job cluster.
 
- It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_SETUP or CANCELLED.
+ It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
+ FAILED_SETUP or CANCELLED.
  """
  handle = global_user_state.get_handle_from_cluster_name(cluster_name)
  assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
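The new terminate_cluster helper above retries sky.down with exponential backoff rather than a fixed delay. A minimal sketch of that retry pattern follows; the Backoff class here is a hypothetical stand-in for common_utils.Backoff (only its constructor arguments and current_backoff() method appear in this diff), assuming a growth factor of roughly 1.6 as the inline comment implies and a cap of initial_backoff * max_backoff_factor.

    import time


    class Backoff:
        """Hypothetical stand-in: the delay grows ~1.6x per call and is
        capped at initial_backoff * max_backoff_factor (assumed semantics)."""

        def __init__(self, initial_backoff: float = 15,
                     max_backoff_factor: int = 20,
                     multiplier: float = 1.6):
            self._delay = initial_backoff
            self._cap = initial_backoff * max_backoff_factor
            self._multiplier = multiplier

        def current_backoff(self) -> float:
            delay = min(self._delay, self._cap)
            self._delay *= self._multiplier
            return delay


    def retry_with_backoff(operation, max_retry: int = 6):
        """Keep calling `operation` until it succeeds or max_retry attempts fail."""
        backoff = Backoff(initial_backoff=15, max_backoff_factor=20)
        for attempt in range(max_retry):
            try:
                return operation()
            except Exception as e:  # pylint: disable=broad-except
                if attempt == max_retry - 1:
                    raise RuntimeError('Operation failed after retries.') from e
                # Sleep longer after each failure so a slow, throttled API is
                # not hammered with rapid retries.
                time.sleep(backoff.current_backoff())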
@@ -101,57 +141,222 @@ def get_job_status(backend: 'backends.CloudVmRayBackend',
  return status
 
 
- def update_managed_job_status(job_id: Optional[int] = None):
- """Update managed job status if the controller job failed abnormally.
+ def _controller_process_alive(pid: int, job_id: int) -> bool:
+ """Check if the controller process is alive."""
+ try:
+ process = psutil.Process(pid)
+ # The last two args of the command line should be --job-id <id>
+ job_args = process.cmdline()[-2:]
+ return process.is_running() and job_args == ['--job-id', str(job_id)]
+ except psutil.NoSuchProcess:
+ return False
+
+
+ def update_managed_jobs_statuses(job_id: Optional[int] = None):
+ """Update managed job status if the controller process failed abnormally.
+
+ Check the status of the controller process. If it is not running, it must
+ have exited abnormally, and we should set the job status to
+ FAILED_CONTROLLER. `end_at` will be set to the current timestamp for the job
+ when above happens, which could be not accurate based on the frequency this
+ function is called.
 
- Check the status of the controller job. If it is not running, it must have
- exited abnormally, and we should set the job status to FAILED_CONTROLLER.
- `end_at` will be set to the current timestamp for the job when above
- happens, which could be not accurate based on the frequency this function
- is called.
+ Note: we expect that job_id, if provided, refers to a nonterminal job or a
+ job that has not completed its cleanup (schedule state not DONE).
  """
- if job_id is None:
- job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None)
- else:
- job_ids = [job_id]
- for job_id_ in job_ids:
- controller_status = job_lib.get_status(job_id_)
+
+ def _cleanup_job_clusters(job_id: int) -> Optional[str]:
+ """Clean up clusters for a job. Returns error message if any.
+
+ This function should not throw any exception. If it fails, it will
+ capture the error message, and log/return it.
+ """
+ error_msg = None
+ tasks = managed_job_state.get_managed_jobs(job_id)
+ for task in tasks:
+ task_name = task['job_name']
+ cluster_name = generate_managed_job_cluster_name(task_name, job_id)
+ handle = global_user_state.get_handle_from_cluster_name(
+ cluster_name)
+ if handle is not None:
+ try:
+ terminate_cluster(cluster_name)
+ except Exception as e: # pylint: disable=broad-except
+ error_msg = (
+ f'Failed to terminate cluster {cluster_name}: '
+ f'{common_utils.format_exception(e, use_bracket=True)}')
+ logger.exception(error_msg, exc_info=e)
+ return error_msg
+
+ # For backwards compatible jobs
+ # TODO(cooperc): Remove before 0.11.0.
+ def _handle_legacy_job(job_id: int):
+ controller_status = job_lib.get_status(job_id)
  if controller_status is None or controller_status.is_terminal():
- logger.error(f'Controller for job {job_id_} has exited abnormally. '
- 'Setting the job status to FAILED_CONTROLLER.')
- tasks = managed_job_state.get_managed_jobs(job_id_)
- for task in tasks:
- task_name = task['job_name']
- # Tear down the abnormal cluster to avoid resource leakage.
- cluster_name = generate_managed_job_cluster_name(
- task_name, job_id_)
- handle = global_user_state.get_handle_from_cluster_name(
- cluster_name)
- if handle is not None:
- backend = backend_utils.get_backend_from_handle(handle)
- max_retry = 3
- for retry_cnt in range(max_retry):
- try:
- backend.teardown(handle, terminate=True)
- break
- except RuntimeError:
- logger.error('Failed to tear down the cluster '
- f'{cluster_name!r}. Retrying '
- f'[{retry_cnt}/{max_retry}].')
-
- # The controller job for this managed job is not running: it must
- # have exited abnormally, and we should set the job status to
- # FAILED_CONTROLLER.
- # The `set_failed` will only update the task's status if the
- # status is non-terminal.
+ logger.error(f'Controller process for legacy job {job_id} is '
+ 'in an unexpected state.')
+
+ cleanup_error = _cleanup_job_clusters(job_id)
+ if cleanup_error:
+ # Unconditionally set the job to failed_controller if the
+ # cleanup fails.
+ managed_job_state.set_failed(
+ job_id,
+ task_id=None,
+ failure_type=managed_job_state.ManagedJobStatus.
+ FAILED_CONTROLLER,
+ failure_reason=
+ 'Legacy controller process has exited abnormally, and '
+ f'cleanup failed: {cleanup_error}. For more details, run: '
+ f'sky jobs logs --controller {job_id}',
+ override_terminal=True)
+ return
+
+ # It's possible for the job to have transitioned to
+ # another terminal state while between when we checked its
+ # state and now. In that case, set_failed won't do
+ # anything, which is fine.
  managed_job_state.set_failed(
- job_id_,
+ job_id,
  task_id=None,
  failure_type=managed_job_state.ManagedJobStatus.
  FAILED_CONTROLLER,
- failure_reason=
- 'Controller process has exited abnormally. For more details,'
- f' run: sky jobs logs --controller {job_id_}')
+ failure_reason=(
+ 'Legacy controller process has exited abnormally. For '
+ f'more details, run: sky jobs logs --controller {job_id}'))
+
+ # Get jobs that need checking (non-terminal or not DONE)
+ job_ids = managed_job_state.get_jobs_to_check_status(job_id)
+ if not job_ids:
+ # job_id is already terminal, or if job_id is None, there are no jobs
+ # that need to be checked.
+ return
+
+ for job_id in job_ids:
+ tasks = managed_job_state.get_managed_jobs(job_id)
+ # Note: controller_pid and schedule_state are in the job_info table
+ # which is joined to the spot table, so all tasks with the same job_id
+ # will have the same value for these columns. This is what lets us just
+ # take tasks[0]['controller_pid'] and tasks[0]['schedule_state'].
+ schedule_state = tasks[0]['schedule_state']
+
+ # Backwards compatibility: this job was submitted when ray was still
+ # used for managing the parallelism of job controllers, before #4485.
+ # TODO(cooperc): Remove before 0.11.0.
+ if (schedule_state is
+ managed_job_state.ManagedJobScheduleState.INVALID):
+ _handle_legacy_job(job_id)
+ continue
+
+ # Handle jobs with schedule state (non-legacy jobs):
+ pid = tasks[0]['controller_pid']
+ if schedule_state == managed_job_state.ManagedJobScheduleState.DONE:
+ # There are two cases where we could get a job that is DONE.
+ # 1. At query time (get_jobs_to_check_status), the job was not yet
+ # DONE, but since then (before get_managed_jobs is called) it has
+ # hit a terminal status, marked itself done, and exited. This is
+ # fine.
+ # 2. The job is DONE, but in a non-terminal status. This is
+ # unexpected. For instance, the task status is RUNNING, but the
+ # job schedule_state is DONE.
+ if all(task['status'].is_terminal() for task in tasks):
+ # Turns out this job is fine, even though it got pulled by
+ # get_jobs_to_check_status. Probably case #1 above.
+ continue
+
+ logger.error(f'Job {job_id} has DONE schedule state, but some '
+ f'tasks are not terminal. Task statuses: '
+ f'{", ".join(task["status"].value for task in tasks)}')
+ failure_reason = ('Inconsistent internal job state. This is a bug.')
+ elif pid is None:
+ # Non-legacy job and controller process has not yet started.
+ controller_status = job_lib.get_status(job_id)
+ if controller_status == job_lib.JobStatus.FAILED_SETUP:
+ # We should fail the case where the controller status is
+ # FAILED_SETUP, as it is due to the failure of dependency setup
+ # on the controller.
+ # TODO(cooperc): We should also handle the case where controller
+ # status is FAILED_DRIVER or FAILED.
+ logger.error('Failed to setup the cloud dependencies for '
+ 'the managed job.')
+ elif (schedule_state in [
+ managed_job_state.ManagedJobScheduleState.INACTIVE,
+ managed_job_state.ManagedJobScheduleState.WAITING,
+ ]):
+ # It is expected that the controller hasn't been started yet.
+ continue
+ elif (schedule_state ==
+ managed_job_state.ManagedJobScheduleState.LAUNCHING):
+ # This is unlikely but technically possible. There's a brief
+ # period between marking job as scheduled (LAUNCHING) and
+ # actually launching the controller process and writing the pid
+ # back to the table.
+ # TODO(cooperc): Find a way to detect if we get stuck in this
+ # state.
+ logger.info(f'Job {job_id} is in {schedule_state.value} state, '
+ 'but controller process hasn\'t started yet.')
+ continue
+
+ logger.error(f'Expected to find a controller pid for state '
+ f'{schedule_state.value} but found none.')
+ failure_reason = f'No controller pid set for {schedule_state.value}'
+ else:
+ logger.debug(f'Checking controller pid {pid}')
+ if _controller_process_alive(pid, job_id):
+ # The controller is still running, so this job is fine.
+ continue
+
+ # Double check job is not already DONE before marking as failed, to
+ # avoid the race where the controller marked itself as DONE and
+ # exited between the state check and the pid check. Since the job
+ # controller process will mark itself DONE _before_ exiting, if it
+ # has exited and it's still not DONE now, it is abnormal.
+ if (managed_job_state.get_job_schedule_state(job_id) ==
+ managed_job_state.ManagedJobScheduleState.DONE):
+ # Never mind, the job is DONE now. This is fine.
+ continue
+
+ logger.error(f'Controller process for {job_id} seems to be dead.')
+ failure_reason = 'Controller process is dead'
+
+ # At this point, either pid is None or process is dead.
+
+ # The controller process for this managed job is not running: it must
+ # have exited abnormally, and we should set the job status to
+ # FAILED_CONTROLLER.
+ logger.error(f'Controller process for job {job_id} has exited '
+ 'abnormally. Setting the job status to FAILED_CONTROLLER.')
+
+ # Cleanup clusters and capture any errors.
+ cleanup_error = _cleanup_job_clusters(job_id)
+ cleanup_error_msg = ''
+ if cleanup_error:
+ cleanup_error_msg = f'Also, cleanup failed: {cleanup_error}. '
+
+ # Set all tasks to FAILED_CONTROLLER, regardless of current status.
+ # This may change a job from SUCCEEDED or another terminal state to
+ # FAILED_CONTROLLER. This is what we want - we are sure that this
+ # controller process crashed, so we want to capture that even if the
+ # underlying job succeeded.
+ # Note: 2+ invocations of update_managed_jobs_statuses could be running
+ # at the same time, so this could override the FAILED_CONTROLLER status
+ # set by another invocation of update_managed_jobs_statuses. That should
+ # be okay. The only difference could be that one process failed to clean
+ # up the cluster while the other succeeds. No matter which
+ # failure_reason ends up in the database, the outcome is acceptable.
+ # We assume that no other code path outside the controller process will
+ # update the job status.
+ managed_job_state.set_failed(
+ job_id,
+ task_id=None,
+ failure_type=managed_job_state.ManagedJobStatus.FAILED_CONTROLLER,
+ failure_reason=
+ f'Controller process has exited abnormally ({failure_reason}). '
+ f'{cleanup_error_msg}'
+ f'For more details, run: sky jobs logs --controller {job_id}',
+ override_terminal=True)
+
+ scheduler.job_done(job_id, idempotent=True)
 
 
  def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
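The _controller_process_alive check introduced above inspects the process command line rather than only the PID, because a PID can be recycled by an unrelated process after the controller exits. A standalone illustration of the same psutil pattern (not part of the package, just a sketch) with the '--job-id <id>' suffix treated as the identifying marker:

    from typing import List

    import psutil


    def pid_belongs_to_job(pid: int, job_id: int) -> bool:
        """Return True only if `pid` is running and its command line ends with
        the expected job arguments; a bare existence check could be fooled by
        a recycled PID belonging to some other process."""
        expected_suffix: List[str] = ['--job-id', str(job_id)]
        try:
            proc = psutil.Process(pid)
            return proc.is_running() and proc.cmdline()[-2:] == expected_suffix
        except psutil.NoSuchProcess:
            return False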
@@ -167,10 +372,32 @@ def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
  subprocess_utils.handle_returncode(returncode, code,
  'Failed to get job time.',
  stdout + stderr)
- stdout = common_utils.decode_payload(stdout)
+ stdout = message_utils.decode_payload(stdout)
  return float(stdout)
 
 
+ def try_to_get_job_end_time(backend: 'backends.CloudVmRayBackend',
+ cluster_name: str) -> float:
+ """Try to get the end time of the job.
+
+ If the job is preempted or we can't connect to the instance for whatever
+ reason, fall back to the current time.
+ """
+ try:
+ return get_job_timestamp(backend, cluster_name, get_end_time=True)
+ except exceptions.CommandError as e:
+ if e.returncode == 255:
+ # Failed to connect - probably the instance was preempted since the
+ # job completed. We shouldn't crash here, so just log and use the
+ # current time.
+ logger.info(f'Failed to connect to the instance {cluster_name} '
+ 'since the job completed. Assuming the instance '
+ 'was preempted.')
+ return time.time()
+ else:
+ raise
+
+
  def event_callback_func(job_id: int, task_id: int, task: 'sky.Task'):
  """Run event callback for the task."""
 
@@ -222,19 +449,21 @@ def generate_managed_job_cluster_name(task_name: str, job_id: int) -> str:
  return f'{cluster_name}-{job_id}'
 
 
- def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
+ def cancel_jobs_by_id(job_ids: Optional[List[int]],
+ all_users: bool = False) -> str:
  """Cancel jobs by id.
 
  If job_ids is None, cancel all jobs.
  """
  if job_ids is None:
- job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None)
+ job_ids = managed_job_state.get_nonterminal_job_ids_by_name(
+ None, all_users)
  job_ids = list(set(job_ids))
- if len(job_ids) == 0:
+ if not job_ids:
  return 'No job to cancel.'
  job_id_str = ', '.join(map(str, job_ids))
  logger.info(f'Cancelling jobs {job_id_str}.')
- cancelled_job_ids = []
+ cancelled_job_ids: List[int] = []
  for job_id in job_ids:
  # Check the status of the managed job status. If it is in
  # terminal state, we can safely skip it.
@@ -247,24 +476,19 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
  f'{job_status.value}. Skipped.')
  continue
 
- update_managed_job_status(job_id)
+ update_managed_jobs_statuses(job_id)
 
  # Send the signal to the jobs controller.
  signal_file = pathlib.Path(SIGNAL_FILE_PREFIX.format(job_id))
- legacy_signal_file = pathlib.Path(
- LEGACY_SIGNAL_FILE_PREFIX.format(job_id))
  # Filelock is needed to prevent race condition between signal
  # check/removal and signal writing.
  with filelock.FileLock(str(signal_file) + '.lock'):
  with signal_file.open('w', encoding='utf-8') as f:
  f.write(UserSignal.CANCEL.value)
  f.flush()
- # Backward compatibility for managed jobs launched before #3419. It
- # can be removed in the future 0.8.0 release.
- shutil.copy(str(signal_file), str(legacy_signal_file))
  cancelled_job_ids.append(job_id)
 
- if len(cancelled_job_ids) == 0:
+ if not cancelled_job_ids:
  return 'No job to cancel.'
  identity_str = f'Job with ID {cancelled_job_ids[0]} is'
  if len(cancelled_job_ids) > 1:
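The cancel path above writes a CANCEL signal file while holding a file lock, so the controller's check-and-remove of the signal cannot race with the write. The consumer side is not shown in this diff; a hypothetical sketch of what a lock-protected read-and-clear counterpart could look like, reusing the same lock path convention:

    import pathlib
    from typing import Optional

    import filelock

    SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'


    def read_and_clear_signal(job_id: int) -> Optional[str]:
        """Hypothetical consumer side: atomically read and remove a pending
        signal under the same lock the writer holds, so a concurrent cancel
        cannot slip in between the read and the unlink."""
        signal_file = pathlib.Path(SIGNAL_FILE_PREFIX.format(job_id))
        with filelock.FileLock(str(signal_file) + '.lock'):
            if not signal_file.exists():
                return None
            signal = signal_file.read_text(encoding='utf-8').strip()
            signal_file.unlink()
            return signal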
@@ -277,7 +501,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
  def cancel_job_by_name(job_name: str) -> str:
  """Cancel a job by name."""
  job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name)
- if len(job_ids) == 0:
+ if not job_ids:
  return f'No running job found with name {job_name!r}.'
  if len(job_ids) > 1:
  return (f'{colorama.Fore.RED}Multiple running jobs found '
@@ -289,52 +513,57 @@ def cancel_job_by_name(job_name: str) -> str:
 
  def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
  """Stream logs by job id."""
- controller_status = job_lib.get_status(job_id)
- status_msg = ('[bold cyan]Waiting for controller process to be RUNNING'
- '{status_str}[/].')
- status_display = rich_utils.safe_status(status_msg.format(status_str=''))
+
+ def should_keep_logging(status: managed_job_state.ManagedJobStatus) -> bool:
+ # If we see CANCELLING, just exit - we could miss some job logs but the
+ # job will be terminated momentarily anyway so we don't really care.
+ return (not status.is_terminal() and
+ status != managed_job_state.ManagedJobStatus.CANCELLING)
+
+ msg = _JOB_WAITING_STATUS_MESSAGE.format(status_str='', job_id=job_id)
+ status_display = rich_utils.safe_status(msg)
  num_tasks = managed_job_state.get_num_tasks(job_id)
 
  with status_display:
- prev_msg = None
- while (controller_status != job_lib.JobStatus.RUNNING and
- (controller_status is None or
- not controller_status.is_terminal())):
- status_str = 'None'
- if controller_status is not None:
- status_str = controller_status.value
- msg = status_msg.format(status_str=f' (status: {status_str})')
- if msg != prev_msg:
- status_display.update(msg)
- prev_msg = msg
- time.sleep(_LOG_STREAM_CHECK_CONTROLLER_GAP_SECONDS)
- controller_status = job_lib.get_status(job_id)
-
- msg = _JOB_WAITING_STATUS_MESSAGE.format(status_str='')
- status_display.update(msg)
  prev_msg = msg
- managed_job_status = managed_job_state.get_status(job_id)
- while managed_job_status is None:
+ while (managed_job_status :=
+ managed_job_state.get_status(job_id)) is None:
  time.sleep(1)
- managed_job_status = managed_job_state.get_status(job_id)
 
- if managed_job_status.is_terminal():
+ if not should_keep_logging(managed_job_status):
  job_msg = ''
  if managed_job_status.is_failed():
  job_msg = ('\nFailure reason: '
  f'{managed_job_state.get_failure_reason(job_id)}')
+ log_file = managed_job_state.get_local_log_file(job_id, None)
+ if log_file is not None:
+ with open(os.path.expanduser(log_file), 'r',
+ encoding='utf-8') as f:
+ # Stream the logs to the console without reading the whole
+ # file into memory.
+ start_streaming = False
+ for line in f:
+ if log_lib.LOG_FILE_START_STREAMING_AT in line:
+ start_streaming = True
+ if start_streaming:
+ print(line, end='', flush=True)
+ return ''
  return (f'{colorama.Fore.YELLOW}'
  f'Job {job_id} is already in terminal state '
- f'{managed_job_status.value}. Logs will not be shown.'
- f'{colorama.Style.RESET_ALL}{job_msg}')
+ f'{managed_job_status.value}. For more details, run: '
+ f'sky jobs logs --controller {job_id}'
+ f'{colorama.Style.RESET_ALL}'
+ f'{job_msg}')
  backend = backends.CloudVmRayBackend()
  task_id, managed_job_status = (
  managed_job_state.get_latest_task_id_status(job_id))
 
- # task_id and managed_job_status can be None if the controller process
- # just started and the managed job status has not set to PENDING yet.
- while (managed_job_status is None or
- not managed_job_status.is_terminal()):
+ # We wait for managed_job_status to be not None above. Once we see that
+ # it's not None, we don't expect it to every become None again.
+ assert managed_job_status is not None, (job_id, task_id,
+ managed_job_status)
+
+ while should_keep_logging(managed_job_status):
  handle = None
  if task_id is not None:
  task_name = managed_job_state.get_task_name(job_id, task_id)
@@ -356,15 +585,19 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
  logger.debug(
  f'INFO: The log is not ready yet{status_str}. '
  f'Waiting for {JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
- msg = _JOB_WAITING_STATUS_MESSAGE.format(status_str=status_str)
+ msg = _JOB_WAITING_STATUS_MESSAGE.format(status_str=status_str,
+ job_id=job_id)
  if msg != prev_msg:
  status_display.update(msg)
  prev_msg = msg
  time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
  task_id, managed_job_status = (
  managed_job_state.get_latest_task_id_status(job_id))
+ assert managed_job_status is not None, (job_id, task_id,
+ managed_job_status)
  continue
- assert managed_job_status is not None
+ assert (managed_job_status ==
+ managed_job_state.ManagedJobStatus.RUNNING)
  assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
  status_display.stop()
  returncode = backend.tail_logs(handle,
@@ -379,29 +612,76 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
  job_statuses = backend.get_job_status(handle, stream_logs=False)
  job_status = list(job_statuses.values())[0]
  assert job_status is not None, 'No job found.'
+ assert task_id is not None, job_id
+
  if job_status != job_lib.JobStatus.CANCELLED:
- assert task_id is not None, job_id
- if task_id < num_tasks - 1 and follow:
- # The log for the current job is finished. We need to
- # wait until next job to be started.
- logger.debug(
- f'INFO: Log for the current task ({task_id}) '
- 'is finished. Waiting for the next task\'s log '
- 'to be started.')
- status_display.update('Waiting for the next task: '
- f'{task_id + 1}.')
+ if not follow:
+ break
+
+ # Logs for retrying failed tasks.
+ if (job_status
+ in job_lib.JobStatus.user_code_failure_states()):
+ task_specs = managed_job_state.get_task_specs(
+ job_id, task_id)
+ if task_specs.get('max_restarts_on_errors', 0) == 0:
+ # We don't need to wait for the managed job status
+ # update, as the job is guaranteed to be in terminal
+ # state afterwards.
+ break
+ print()
+ status_display.update(
+ ux_utils.spinner_message(
+ 'Waiting for next restart for the failed task'))
  status_display.start()
- original_task_id = task_id
- while True:
- task_id, managed_job_status = (
- managed_job_state.get_latest_task_id_status(
- job_id))
- if original_task_id != task_id:
- break
+
+ def is_managed_job_status_updated(
+ status: Optional[managed_job_state.ManagedJobStatus]
+ ) -> bool:
+ """Check if local managed job status reflects remote
+ job failure.
+
+ Ensures synchronization between remote cluster
+ failure detection (JobStatus.FAILED) and controller
+ retry logic.
+ """
+ return (status !=
+ managed_job_state.ManagedJobStatus.RUNNING)
+
+ while not is_managed_job_status_updated(
+ managed_job_status :=
+ managed_job_state.get_status(job_id)):
  time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
+ assert managed_job_status is not None, (
+ job_id, managed_job_status)
  continue
- else:
+
+ if task_id == num_tasks - 1:
  break
+
+ # The log for the current job is finished. We need to
+ # wait until next job to be started.
+ logger.debug(
+ f'INFO: Log for the current task ({task_id}) '
+ 'is finished. Waiting for the next task\'s log '
+ 'to be started.')
+ # Add a newline to avoid the status display below
+ # removing the last line of the task output.
+ print()
+ status_display.update(
+ ux_utils.spinner_message(
+ f'Waiting for the next task: {task_id + 1}'))
+ status_display.start()
+ original_task_id = task_id
+ while True:
+ task_id, managed_job_status = (
+ managed_job_state.get_latest_task_id_status(job_id))
+ if original_task_id != task_id:
+ break
+ time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
+ assert managed_job_status is not None, (job_id, task_id,
+ managed_job_status)
+ continue
+
  # The job can be cancelled by the user or the controller (when
  # the cluster is partially preempted).
  logger.debug(
@@ -415,7 +695,7 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
  # state.
  managed_job_status = managed_job_state.get_status(job_id)
  assert managed_job_status is not None, job_id
- if managed_job_status.is_terminal():
+ if not should_keep_logging(managed_job_status):
  break
  logger.info(f'{colorama.Fore.YELLOW}The job cluster is preempted '
  f'or failed.{colorama.Style.RESET_ALL}')
@@ -430,6 +710,7 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
  # managed job state is updated.
  time.sleep(3 * JOB_STATUS_CHECK_GAP_SECONDS)
  managed_job_status = managed_job_state.get_status(job_id)
+ assert managed_job_status is not None, (job_id, managed_job_status)
 
  # The managed_job_status may not be in terminal status yet, since the
  # controller has not updated the managed job state yet. We wait for a while,
@@ -437,15 +718,16 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
  wait_seconds = 0
  managed_job_status = managed_job_state.get_status(job_id)
  assert managed_job_status is not None, job_id
- while (not managed_job_status.is_terminal() and follow and
+ while (should_keep_logging(managed_job_status) and follow and
  wait_seconds < _FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS):
  time.sleep(1)
  wait_seconds += 1
  managed_job_status = managed_job_state.get_status(job_id)
  assert managed_job_status is not None, job_id
 
- logger.info(f'Logs finished for job {job_id} '
- f'(status: {managed_job_status.value}).')
+ logger.info(
+ ux_utils.finishing_message(f'Managed job finished: {job_id} '
+ f'(status: {managed_job_status.value}).'))
  return ''
 
 
@@ -458,6 +740,7 @@ def stream_logs(job_id: Optional[int],
  job_id = managed_job_state.get_latest_job_id()
  if job_id is None:
  return 'No managed job found.'
+
  if controller:
  if job_id is None:
  assert job_name is not None
@@ -465,32 +748,99 @@ def stream_logs(job_id: Optional[int],
  # We manually filter the jobs by name, instead of using
  # get_nonterminal_job_ids_by_name, as with `controller=True`, we
  # should be able to show the logs for jobs in terminal states.
- managed_jobs = list(
- filter(lambda job: job['job_name'] == job_name, managed_jobs))
- if len(managed_jobs) == 0:
+ managed_job_ids: Set[int] = {
+ job['job_id']
+ for job in managed_jobs
+ if job['job_name'] == job_name
+ }
+ if not managed_job_ids:
  return f'No managed job found with name {job_name!r}.'
- if len(managed_jobs) > 1:
- job_ids_str = ', '.join(job['job_id'] for job in managed_jobs)
- raise ValueError(
- f'Multiple managed jobs found with name {job_name!r} (Job '
- f'IDs: {job_ids_str}). Please specify the job_id instead.')
- job_id = managed_jobs[0]['job_id']
+ if len(managed_job_ids) > 1:
+ job_ids_str = ', '.join(
+ str(job_id) for job_id in managed_job_ids)
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError(
+ f'Multiple managed jobs found with name {job_name!r} '
+ f'(Job IDs: {job_ids_str}). Please specify the job_id '
+ 'instead.')
+ job_id = managed_job_ids.pop()
  assert job_id is not None, (job_id, job_name)
- # TODO: keep the following code sync with
- # job_lib.JobLibCodeGen.tail_logs, we do not directly call that function
- # as the following code need to be run in the current machine, instead
- # of running remotely.
- run_timestamp = job_lib.get_run_timestamp(job_id)
- if run_timestamp is None:
- return f'No managed job contrller log found with job_id {job_id}.'
- log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp)
- log_lib.tail_logs(job_id=job_id, log_dir=log_dir, follow=follow)
+
+ controller_log_path = os.path.join(
+ os.path.expanduser(managed_job_constants.JOBS_CONTROLLER_LOGS_DIR),
+ f'{job_id}.log')
+ job_status = None
+
+ # Wait for the log file to be written
+ while not os.path.exists(controller_log_path):
+ if not follow:
+ # Assume that the log file hasn't been written yet. Since we
+ # aren't following, just return.
+ return ''
+
+ job_status = managed_job_state.get_status(job_id)
+ if job_status is None:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError(f'Job {job_id} not found.')
+ if job_status.is_terminal():
+ # Don't keep waiting. If the log file is not created by this
+ # point, it never will be. This job may have been submitted
+ # using an old version that did not create the log file, so this
+ # is not considered an exceptional case.
+ return ''
+
+ time.sleep(log_lib.SKY_LOG_WAITING_GAP_SECONDS)
+
+ # This code is based on log_lib.tail_logs. We can't use that code
+ # exactly because state works differently between managed jobs and
+ # normal jobs.
+ with open(controller_log_path, 'r', newline='', encoding='utf-8') as f:
+ # Note: we do not need to care about start_stream_at here, since
+ # that should be in the job log printed above.
+ for line in f:
+ print(line, end='')
+ # Flush.
+ print(end='', flush=True)
+
+ if follow:
+ while True:
+ # Print all new lines, if there are any.
+ line = f.readline()
+ while line is not None and line != '':
+ print(line, end='')
+ line = f.readline()
+
+ # Flush.
+ print(end='', flush=True)
+
+ # Check if the job if finished.
+ # TODO(cooperc): The controller can still be
+ # cleaning up if job is in a terminal status
+ # (e.g. SUCCEEDED). We want to follow those logs
+ # too. Use DONE instead?
+ job_status = managed_job_state.get_status(job_id)
+ assert job_status is not None, (job_id, job_name)
+ if job_status.is_terminal():
+ break
+
+ time.sleep(log_lib.SKY_LOG_TAILING_GAP_SECONDS)
+
+ # Wait for final logs to be written.
+ time.sleep(1 + log_lib.SKY_LOG_TAILING_GAP_SECONDS)
+
+ # Print any remaining logs including incomplete line.
+ print(f.read(), end='', flush=True)
+
+ if follow:
+ return ux_utils.finishing_message(
+ f'Job finished (status: {job_status}).')
+
  return ''
 
  if job_id is None:
  assert job_name is not None
  job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name)
- if len(job_ids) == 0:
+ if not job_ids:
  return f'No running managed job found with name {job_name!r}.'
  if len(job_ids) > 1:
  raise ValueError(
@@ -520,6 +870,7 @@ def dump_managed_job_queue() -> str:
  job_duration = 0
  job['job_duration'] = job_duration
  job['status'] = job['status'].value
+ job['schedule_state'] = job['schedule_state'].value
 
  cluster_name = generate_managed_job_cluster_name(
  job['task_name'], job['job_id'])
@@ -534,12 +885,12 @@ def dump_managed_job_queue() -> str:
  job['cluster_resources'] = '-'
  job['region'] = '-'
 
- return common_utils.encode_payload(jobs)
+ return message_utils.encode_payload(jobs)
 
 
  def load_managed_job_queue(payload: str) -> List[Dict[str, Any]]:
  """Load job queue from json string."""
- jobs = common_utils.decode_payload(payload)
+ jobs = message_utils.decode_payload(payload)
  for job in jobs:
  job['status'] = managed_job_state.ManagedJobStatus(job['status'])
  return jobs
@@ -568,6 +919,7 @@ def _get_job_status_from_tasks(
  @typing.overload
  def format_job_table(tasks: List[Dict[str, Any]],
  show_all: bool,
+ show_user: bool,
  return_rows: Literal[False] = False,
  max_jobs: Optional[int] = None) -> str:
  ...
@@ -576,6 +928,7 @@ def format_job_table(tasks: List[Dict[str, Any]],
  @typing.overload
  def format_job_table(tasks: List[Dict[str, Any]],
  show_all: bool,
+ show_user: bool,
  return_rows: Literal[True],
  max_jobs: Optional[int] = None) -> List[List[str]]:
  ...
@@ -584,6 +937,7 @@ def format_job_table(tasks: List[Dict[str, Any]],
  def format_job_table(
  tasks: List[Dict[str, Any]],
  show_all: bool,
+ show_user: bool,
  return_rows: bool = False,
  max_jobs: Optional[int] = None) -> Union[str, List[List[str]]]:
  """Returns managed jobs as a formatted string.
@@ -599,11 +953,21 @@ def format_job_table(
  a list of "rows" (each of which is a list of str).
  """
  jobs = collections.defaultdict(list)
+ # Check if the tasks have user information from kubernetes.
+ # This is only used for sky status --kubernetes.
+ tasks_have_k8s_user = any([task.get('user') for task in tasks])
+ if max_jobs and tasks_have_k8s_user:
+ raise ValueError('max_jobs is not supported when tasks have user info.')
+
+ def get_hash(task):
+ if tasks_have_k8s_user:
+ return (task['user'], task['job_id'])
+ return task['job_id']
+
  for task in tasks:
  # The tasks within the same job_id are already sorted
  # by the task_id.
- jobs[task['job_id']].append(task)
- jobs = dict(jobs)
+ jobs[get_hash(task)].append(task)
 
  status_counts: Dict[str, int] = collections.defaultdict(int)
  for job_tasks in jobs.values():
@@ -611,17 +975,29 @@ def format_job_table(
  if not managed_job_status.is_terminal():
  status_counts[managed_job_status.value] += 1
 
- if max_jobs is not None:
- job_ids = sorted(jobs.keys(), reverse=True)
- job_ids = job_ids[:max_jobs]
- jobs = {job_id: jobs[job_id] for job_id in job_ids}
+ user_cols: List[str] = []
+ if show_user:
+ user_cols = ['USER']
+ if show_all:
+ user_cols.append('USER_ID')
 
  columns = [
- 'ID', 'TASK', 'NAME', 'RESOURCES', 'SUBMITTED', 'TOT. DURATION',
- 'JOB DURATION', '#RECOVERIES', 'STATUS'
+ 'ID',
+ 'TASK',
+ 'NAME',
+ *user_cols,
+ 'RESOURCES',
+ 'SUBMITTED',
+ 'TOT. DURATION',
+ 'JOB DURATION',
+ '#RECOVERIES',
+ 'STATUS',
  ]
  if show_all:
- columns += ['STARTED', 'CLUSTER', 'REGION', 'FAILURE']
+ # TODO: move SCHED. STATE to a separate flag (e.g. --debug)
+ columns += ['STARTED', 'CLUSTER', 'REGION', 'SCHED. STATE', 'DETAILS']
+ if tasks_have_k8s_user:
+ columns.insert(0, 'USER')
  job_table = log_utils.create_table(columns)
 
  status_counts: Dict[str, int] = collections.defaultdict(int)
@@ -636,9 +1012,33 @@ def format_job_table(
  for task in all_tasks:
  # The tasks within the same job_id are already sorted
  # by the task_id.
- jobs[task['job_id']].append(task)
+ jobs[get_hash(task)].append(task)
+
+ def generate_details(failure_reason: Optional[str]) -> str:
+ if failure_reason is not None:
+ return f'Failure: {failure_reason}'
+ return '-'
+
+ def get_user_column_values(task: Dict[str, Any]) -> List[str]:
+ user_values: List[str] = []
+ if show_user:
+
+ user_name = '-'
+ user_hash = task.get('user_hash', None)
+ if user_hash:
+ user = global_user_state.get_user(user_hash)
+ user_name = user.name if user.name else '-'
+ user_values = [user_name]
+
+ if show_all:
+ user_values.append(user_hash if user_hash is not None else '-')
+
+ return user_values
+
+ for job_hash, job_tasks in jobs.items():
+ if show_all:
+ schedule_state = job_tasks[0]['schedule_state']
 
- for job_id, job_tasks in jobs.items():
  if len(job_tasks) > 1:
  # Aggregate the tasks into a new row in the table.
  job_name = job_tasks[0]['job_name']
@@ -661,7 +1061,6 @@ def format_job_table(
                     end_at = None
                 recovery_cnt += task['recovery_count']
 
-            failure_reason = job_tasks[current_task_id]['failure_reason']
             job_duration = log_utils.readable_time_duration(0,
                                                             job_duration,
                                                             absolute=True)
@@ -674,10 +1073,14 @@ def format_job_table(
             if not managed_job_status.is_terminal():
                 status_str += f' (task: {current_task_id})'
 
+            user_values = get_user_column_values(job_tasks[0])
+
+            job_id = job_hash[1] if tasks_have_k8s_user else job_hash
             job_values = [
                 job_id,
                 '',
                 job_name,
+                *user_values,
                 '-',
                 submitted,
                 total_duration,
@@ -686,12 +1089,16 @@ def format_job_table(
                 status_str,
             ]
             if show_all:
+                failure_reason = job_tasks[current_task_id]['failure_reason']
                 job_values.extend([
                     '-',
                     '-',
                     '-',
-                    failure_reason if failure_reason is not None else '-',
+                    job_tasks[0]['schedule_state'],
+                    generate_details(failure_reason),
                 ])
+            if tasks_have_k8s_user:
+                job_values.insert(0, job_tasks[0].get('user', '-'))
             job_table.add_row(job_values)
 
         for task in job_tasks:
@@ -700,10 +1107,12 @@ def format_job_table(
             job_duration = log_utils.readable_time_duration(
                 0, task['job_duration'], absolute=True)
             submitted = log_utils.readable_time_duration(task['submitted_at'])
+            user_values = get_user_column_values(task)
             values = [
                 task['job_id'] if len(job_tasks) == 1 else ' \u21B3',
                 task['task_id'] if len(job_tasks) > 1 else '-',
                 task['task_name'],
+                *user_values,
                 task['resources'],
                 # SUBMITTED
                 submitted if submitted != '-' else submitted,
@@ -716,14 +1125,20 @@ def format_job_table(
                 task['status'].colored_str(),
             ]
             if show_all:
+                # schedule_state is only set at the job level, so if we have
+                # more than one task, only display on the aggregated row.
+                schedule_state = (task['schedule_state']
+                                  if len(job_tasks) == 1 else '-')
                 values.extend([
                     # STARTED
                     log_utils.readable_time_duration(task['start_at']),
                     task['cluster_resources'],
                     task['region'],
-                    task['failure_reason']
-                    if task['failure_reason'] is not None else '-',
+                    schedule_state,
+                    generate_details(task['failure_reason']),
                 ])
+            if tasks_have_k8s_user:
+                values.insert(0, task.get('user', '-'))
             job_table.add_row(values)
 
         if len(job_tasks) > 1:
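
With the signature change above, callers of format_job_table now pass show_user explicitly. A minimal usage sketch (tasks stands for the list of task dicts this module already works with; the variable names are illustrative):

    # Formatted string for display:
    table_str = format_job_table(tasks, show_all=False, show_user=True)
    # Raw rows for further processing:
    rows = format_job_table(tasks, show_all=True, show_user=True,
                            return_rows=True)
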
@@ -751,36 +1166,34 @@ class ManagedJobCodeGen:
 
     >> codegen = ManagedJobCodeGen.show_jobs(...)
     """
-    # TODO: the try..except.. block is for backward compatibility. Remove it in
-    # v0.8.0.
     _PREFIX = textwrap.dedent("""\
-        managed_job_version = 0
-        try:
-            from sky.jobs import utils
-            from sky.jobs import constants as managed_job_constants
-            from sky.jobs import state as managed_job_state
-
-            managed_job_version = managed_job_constants.MANAGED_JOBS_VERSION
-        except ImportError:
-            from sky.spot import spot_state as managed_job_state
-            from sky.spot import spot_utils as utils
+        from sky.jobs import utils
+        from sky.jobs import state as managed_job_state
+        from sky.jobs import constants as managed_job_constants
+
+        managed_job_version = managed_job_constants.MANAGED_JOBS_VERSION
         """)
 
     @classmethod
     def get_job_table(cls) -> str:
         code = textwrap.dedent("""\
-        if managed_job_version < 1:
-            job_table = utils.dump_spot_job_queue()
-        else:
-            job_table = utils.dump_managed_job_queue()
+        job_table = utils.dump_managed_job_queue()
         print(job_table, flush=True)
         """)
         return cls._build(code)
 
     @classmethod
-    def cancel_jobs_by_id(cls, job_ids: Optional[List[int]]) -> str:
+    def cancel_jobs_by_id(cls,
+                          job_ids: Optional[List[int]],
+                          all_users: bool = False) -> str:
         code = textwrap.dedent(f"""\
-        msg = utils.cancel_jobs_by_id({job_ids})
+        if managed_job_version < 2:
+            # For backward compatibility, since all_users is not supported
+            # before #4787. Assume th
+            # TODO(cooperc): Remove compatibility before 0.12.0
+            msg = utils.cancel_jobs_by_id({job_ids})
+        else:
+            msg = utils.cancel_jobs_by_id({job_ids}, all_users={all_users})
         print(msg, end="", flush=True)
         """)
         return cls._build(code)
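
A usage sketch for the extended codegen (the return value is the shell command string assembled by _build below; the job ids are illustrative):

    # Cancel specific jobs for the calling user:
    cmd = ManagedJobCodeGen.cancel_jobs_by_id([3, 4])
    # Also match jobs owned by other users, honored when the controller's
    # managed_job_version is >= 2 (older controllers fall back to the
    # single-user call, as the generated snippet above shows):
    cmd = ManagedJobCodeGen.cancel_jobs_by_id([3, 4], all_users=True)
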
@@ -793,33 +1206,24 @@ class ManagedJobCodeGen:
         """)
         return cls._build(code)
 
+    @classmethod
+    def get_all_job_ids_by_name(cls, job_name: Optional[str]) -> str:
+        code = textwrap.dedent(f"""\
+        from sky.utils import message_utils
+        job_id = managed_job_state.get_all_job_ids_by_name({job_name!r})
+        print(message_utils.encode_payload(job_id), end="", flush=True)
+        """)
+        return cls._build(code)
+
     @classmethod
     def stream_logs(cls,
                     job_name: Optional[str],
                     job_id: Optional[int],
                     follow: bool = True,
                     controller: bool = False) -> str:
-        # We inspect the source code of the function here for backward
-        # compatibility.
-        # TODO: change to utils.stream_logs(job_id, job_name, follow) in v0.8.0.
-        # Import libraries required by `stream_logs`. The try...except... block
-        # should be removed in v0.8.0.
-        code = textwrap.dedent("""\
-            import os
-
-            from sky.skylet import job_lib, log_lib
-            from sky.skylet import constants
-            try:
-                from sky.jobs.utils import stream_logs_by_id
-            except ImportError:
-                from sky.spot.spot_utils import stream_logs_by_id
-            from typing import Optional
-            """)
-        code += inspect.getsource(stream_logs)
-        code += textwrap.dedent(f"""\
-
-        msg = stream_logs({job_id!r}, {job_name!r},
-                          follow={follow}, controller={controller})
+        code = textwrap.dedent(f"""\
+        msg = utils.stream_logs({job_id!r}, {job_name!r},
+                                follow={follow}, controller={controller})
         print(msg, flush=True)
         """)
         return cls._build(code)
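
The new get_all_job_ids_by_name codegen prints an encoded payload instead of a bare value, so the caller has to decode the command's stdout. A rough sketch of the consumer side, assuming the matching message_utils.decode_payload helper is used to parse the output, with run_on_controller standing in for however the command actually gets executed:

    code = ManagedJobCodeGen.get_all_job_ids_by_name('my-job')
    stdout = run_on_controller(code)  # hypothetical executor
    job_ids = message_utils.decode_payload(stdout)
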
@@ -829,13 +1233,13 @@ class ManagedJobCodeGen:
         dag_name = managed_job_dag.name
         # Add the managed job to queue table.
         code = textwrap.dedent(f"""\
-            managed_job_state.set_job_name({job_id}, {dag_name!r})
+            managed_job_state.set_job_info({job_id}, {dag_name!r})
             """)
         for task_id, task in enumerate(managed_job_dag.tasks):
             resources_str = backend_utils.get_task_resources_str(
                 task, is_managed_job=True)
             code += textwrap.dedent(f"""\
-                managed_job_state.set_pending({job_id}, {task_id},
+                managed_job_state.set_pending({job_id}, {task_id},
                                               {task.name!r}, {resources_str!r})
                 """)
         return cls._build(code)
@@ -843,4 +1247,9 @@ class ManagedJobCodeGen:
     @classmethod
     def _build(cls, code: str) -> str:
         generated_code = cls._PREFIX + '\n' + code
-        return f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(generated_code)}'
+        # Use the local user id to make sure the operation goes to the correct
+        # user.
+        return (
+            f'export {constants.USER_ID_ENV_VAR}='
+            f'"{common_utils.get_user_hash()}"; '
+            f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(generated_code)}')
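
For clarity, the command _build now returns has roughly this shape (shown symbolically; the user hash is whatever common_utils.get_user_hash() returns on the calling machine):

    export <USER_ID_ENV_VAR>="<caller user hash>"; <SKY_PYTHON_CMD> -u -c '<prefix + generated code>'

so every generated snippet runs on the controller under the identity of the user who issued it.
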