skypilot-nightly 1.0.0.dev20250509__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (512)
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +25 -7
  3. sky/adaptors/common.py +24 -1
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/hyperbolic.py +8 -0
  7. sky/adaptors/kubernetes.py +149 -18
  8. sky/adaptors/nebius.py +170 -17
  9. sky/adaptors/primeintellect.py +1 -0
  10. sky/adaptors/runpod.py +68 -0
  11. sky/adaptors/seeweb.py +167 -0
  12. sky/adaptors/shadeform.py +89 -0
  13. sky/admin_policy.py +187 -4
  14. sky/authentication.py +179 -225
  15. sky/backends/__init__.py +4 -2
  16. sky/backends/backend.py +22 -9
  17. sky/backends/backend_utils.py +1299 -380
  18. sky/backends/cloud_vm_ray_backend.py +1715 -518
  19. sky/backends/docker_utils.py +1 -1
  20. sky/backends/local_docker_backend.py +11 -6
  21. sky/backends/wheel_utils.py +37 -9
  22. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  23. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  24. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  25. sky/{clouds/service_catalog → catalog}/common.py +89 -48
  26. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  27. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  28. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +30 -40
  29. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +42 -15
  31. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  33. sky/catalog/data_fetchers/fetch_nebius.py +335 -0
  34. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  35. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  36. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  37. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  38. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  39. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  40. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  41. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  42. sky/catalog/hyperbolic_catalog.py +136 -0
  43. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  44. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  45. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  46. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  47. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  48. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  49. sky/catalog/primeintellect_catalog.py +95 -0
  50. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  51. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  52. sky/catalog/seeweb_catalog.py +184 -0
  53. sky/catalog/shadeform_catalog.py +165 -0
  54. sky/catalog/ssh_catalog.py +167 -0
  55. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  56. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  57. sky/check.py +491 -203
  58. sky/cli.py +5 -6005
  59. sky/client/{cli.py → cli/command.py} +2477 -1885
  60. sky/client/cli/deprecation_utils.py +99 -0
  61. sky/client/cli/flags.py +359 -0
  62. sky/client/cli/table_utils.py +320 -0
  63. sky/client/common.py +70 -32
  64. sky/client/oauth.py +82 -0
  65. sky/client/sdk.py +1203 -297
  66. sky/client/sdk_async.py +833 -0
  67. sky/client/service_account_auth.py +47 -0
  68. sky/cloud_stores.py +73 -0
  69. sky/clouds/__init__.py +13 -0
  70. sky/clouds/aws.py +358 -93
  71. sky/clouds/azure.py +105 -83
  72. sky/clouds/cloud.py +127 -36
  73. sky/clouds/cudo.py +68 -50
  74. sky/clouds/do.py +66 -48
  75. sky/clouds/fluidstack.py +63 -44
  76. sky/clouds/gcp.py +339 -110
  77. sky/clouds/hyperbolic.py +293 -0
  78. sky/clouds/ibm.py +70 -49
  79. sky/clouds/kubernetes.py +563 -162
  80. sky/clouds/lambda_cloud.py +74 -54
  81. sky/clouds/nebius.py +206 -80
  82. sky/clouds/oci.py +88 -66
  83. sky/clouds/paperspace.py +61 -44
  84. sky/clouds/primeintellect.py +317 -0
  85. sky/clouds/runpod.py +164 -74
  86. sky/clouds/scp.py +89 -83
  87. sky/clouds/seeweb.py +466 -0
  88. sky/clouds/shadeform.py +400 -0
  89. sky/clouds/ssh.py +263 -0
  90. sky/clouds/utils/aws_utils.py +10 -4
  91. sky/clouds/utils/gcp_utils.py +87 -11
  92. sky/clouds/utils/oci_utils.py +38 -14
  93. sky/clouds/utils/scp_utils.py +177 -124
  94. sky/clouds/vast.py +99 -77
  95. sky/clouds/vsphere.py +51 -40
  96. sky/core.py +349 -139
  97. sky/dag.py +15 -0
  98. sky/dashboard/out/404.html +1 -1
  99. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  100. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  101. sky/dashboard/out/_next/static/chunks/1871-74503c8e80fd253b.js +6 -0
  102. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  103. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  105. sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
  106. sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
  107. sky/dashboard/out/_next/static/chunks/3785.ad6adaa2a0fa9768.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  110. sky/dashboard/out/_next/static/chunks/4725.a830b5c9e7867c92.js +1 -0
  111. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  112. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  113. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  115. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  116. sky/dashboard/out/_next/static/chunks/6601-06114c982db410b6.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
  118. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  119. sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  121. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  122. sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  124. sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  126. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  127. sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
  128. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  129. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  131. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  133. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  134. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
  136. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a37d2063af475a1c.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/pages/clusters-d44859594e6f8064.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  139. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  141. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
  143. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-6edeb7d06032adfc.js +21 -0
  144. sky/dashboard/out/_next/static/chunks/pages/jobs-479dde13399cf270.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/users-5ab3b907622cf0fe.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  148. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-c5a3eeee1c218af1.js +1 -0
  149. sky/dashboard/out/_next/static/chunks/pages/workspaces-22b23febb3e89ce1.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
  151. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  152. sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
  153. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  154. sky/dashboard/out/clusters/[cluster].html +1 -1
  155. sky/dashboard/out/clusters.html +1 -1
  156. sky/dashboard/out/config.html +1 -0
  157. sky/dashboard/out/index.html +1 -1
  158. sky/dashboard/out/infra/[context].html +1 -0
  159. sky/dashboard/out/infra.html +1 -0
  160. sky/dashboard/out/jobs/[job].html +1 -1
  161. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  162. sky/dashboard/out/jobs.html +1 -1
  163. sky/dashboard/out/users.html +1 -0
  164. sky/dashboard/out/volumes.html +1 -0
  165. sky/dashboard/out/workspace/new.html +1 -0
  166. sky/dashboard/out/workspaces/[name].html +1 -0
  167. sky/dashboard/out/workspaces.html +1 -0
  168. sky/data/data_utils.py +137 -1
  169. sky/data/mounting_utils.py +269 -84
  170. sky/data/storage.py +1451 -1807
  171. sky/data/storage_utils.py +43 -57
  172. sky/exceptions.py +132 -2
  173. sky/execution.py +206 -63
  174. sky/global_user_state.py +2374 -586
  175. sky/jobs/__init__.py +5 -0
  176. sky/jobs/client/sdk.py +242 -65
  177. sky/jobs/client/sdk_async.py +143 -0
  178. sky/jobs/constants.py +9 -8
  179. sky/jobs/controller.py +839 -277
  180. sky/jobs/file_content_utils.py +80 -0
  181. sky/jobs/log_gc.py +201 -0
  182. sky/jobs/recovery_strategy.py +398 -152
  183. sky/jobs/scheduler.py +315 -189
  184. sky/jobs/server/core.py +829 -255
  185. sky/jobs/server/server.py +156 -115
  186. sky/jobs/server/utils.py +136 -0
  187. sky/jobs/state.py +2092 -701
  188. sky/jobs/utils.py +1242 -160
  189. sky/logs/__init__.py +21 -0
  190. sky/logs/agent.py +108 -0
  191. sky/logs/aws.py +243 -0
  192. sky/logs/gcp.py +91 -0
  193. sky/metrics/__init__.py +0 -0
  194. sky/metrics/utils.py +443 -0
  195. sky/models.py +78 -1
  196. sky/optimizer.py +164 -70
  197. sky/provision/__init__.py +90 -4
  198. sky/provision/aws/config.py +147 -26
  199. sky/provision/aws/instance.py +135 -50
  200. sky/provision/azure/instance.py +10 -5
  201. sky/provision/common.py +13 -1
  202. sky/provision/cudo/cudo_machine_type.py +1 -1
  203. sky/provision/cudo/cudo_utils.py +14 -8
  204. sky/provision/cudo/cudo_wrapper.py +72 -71
  205. sky/provision/cudo/instance.py +10 -6
  206. sky/provision/do/instance.py +10 -6
  207. sky/provision/do/utils.py +4 -3
  208. sky/provision/docker_utils.py +114 -23
  209. sky/provision/fluidstack/instance.py +13 -8
  210. sky/provision/gcp/__init__.py +1 -0
  211. sky/provision/gcp/config.py +301 -19
  212. sky/provision/gcp/constants.py +218 -0
  213. sky/provision/gcp/instance.py +36 -8
  214. sky/provision/gcp/instance_utils.py +18 -4
  215. sky/provision/gcp/volume_utils.py +247 -0
  216. sky/provision/hyperbolic/__init__.py +12 -0
  217. sky/provision/hyperbolic/config.py +10 -0
  218. sky/provision/hyperbolic/instance.py +437 -0
  219. sky/provision/hyperbolic/utils.py +373 -0
  220. sky/provision/instance_setup.py +93 -14
  221. sky/provision/kubernetes/__init__.py +5 -0
  222. sky/provision/kubernetes/config.py +9 -52
  223. sky/provision/kubernetes/constants.py +17 -0
  224. sky/provision/kubernetes/instance.py +789 -247
  225. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  226. sky/provision/kubernetes/network.py +27 -17
  227. sky/provision/kubernetes/network_utils.py +40 -43
  228. sky/provision/kubernetes/utils.py +1192 -531
  229. sky/provision/kubernetes/volume.py +282 -0
  230. sky/provision/lambda_cloud/instance.py +22 -16
  231. sky/provision/nebius/constants.py +50 -0
  232. sky/provision/nebius/instance.py +19 -6
  233. sky/provision/nebius/utils.py +196 -91
  234. sky/provision/oci/instance.py +10 -5
  235. sky/provision/paperspace/instance.py +10 -7
  236. sky/provision/paperspace/utils.py +1 -1
  237. sky/provision/primeintellect/__init__.py +10 -0
  238. sky/provision/primeintellect/config.py +11 -0
  239. sky/provision/primeintellect/instance.py +454 -0
  240. sky/provision/primeintellect/utils.py +398 -0
  241. sky/provision/provisioner.py +110 -36
  242. sky/provision/runpod/__init__.py +5 -0
  243. sky/provision/runpod/instance.py +27 -6
  244. sky/provision/runpod/utils.py +51 -18
  245. sky/provision/runpod/volume.py +180 -0
  246. sky/provision/scp/__init__.py +15 -0
  247. sky/provision/scp/config.py +93 -0
  248. sky/provision/scp/instance.py +531 -0
  249. sky/provision/seeweb/__init__.py +11 -0
  250. sky/provision/seeweb/config.py +13 -0
  251. sky/provision/seeweb/instance.py +807 -0
  252. sky/provision/shadeform/__init__.py +11 -0
  253. sky/provision/shadeform/config.py +12 -0
  254. sky/provision/shadeform/instance.py +351 -0
  255. sky/provision/shadeform/shadeform_utils.py +83 -0
  256. sky/provision/ssh/__init__.py +18 -0
  257. sky/provision/vast/instance.py +13 -8
  258. sky/provision/vast/utils.py +10 -7
  259. sky/provision/vsphere/common/vim_utils.py +1 -2
  260. sky/provision/vsphere/instance.py +15 -10
  261. sky/provision/vsphere/vsphere_utils.py +9 -19
  262. sky/py.typed +0 -0
  263. sky/resources.py +844 -118
  264. sky/schemas/__init__.py +0 -0
  265. sky/schemas/api/__init__.py +0 -0
  266. sky/schemas/api/responses.py +225 -0
  267. sky/schemas/db/README +4 -0
  268. sky/schemas/db/env.py +90 -0
  269. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  270. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  271. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  272. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  273. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  274. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  275. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  276. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  277. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  278. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  279. sky/schemas/db/script.py.mako +28 -0
  280. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  281. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  282. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  283. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  284. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  285. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  286. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  287. sky/schemas/generated/__init__.py +0 -0
  288. sky/schemas/generated/autostopv1_pb2.py +36 -0
  289. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  290. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  291. sky/schemas/generated/jobsv1_pb2.py +86 -0
  292. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  293. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  294. sky/schemas/generated/managed_jobsv1_pb2.py +74 -0
  295. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  296. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  297. sky/schemas/generated/servev1_pb2.py +58 -0
  298. sky/schemas/generated/servev1_pb2.pyi +115 -0
  299. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  300. sky/serve/autoscalers.py +357 -5
  301. sky/serve/client/impl.py +310 -0
  302. sky/serve/client/sdk.py +47 -139
  303. sky/serve/client/sdk_async.py +130 -0
  304. sky/serve/constants.py +10 -8
  305. sky/serve/controller.py +64 -19
  306. sky/serve/load_balancer.py +106 -60
  307. sky/serve/load_balancing_policies.py +115 -1
  308. sky/serve/replica_managers.py +273 -162
  309. sky/serve/serve_rpc_utils.py +179 -0
  310. sky/serve/serve_state.py +554 -251
  311. sky/serve/serve_utils.py +733 -220
  312. sky/serve/server/core.py +66 -711
  313. sky/serve/server/impl.py +1093 -0
  314. sky/serve/server/server.py +21 -18
  315. sky/serve/service.py +133 -48
  316. sky/serve/service_spec.py +135 -16
  317. sky/serve/spot_placer.py +3 -0
  318. sky/server/auth/__init__.py +0 -0
  319. sky/server/auth/authn.py +50 -0
  320. sky/server/auth/loopback.py +38 -0
  321. sky/server/auth/oauth2_proxy.py +200 -0
  322. sky/server/common.py +475 -181
  323. sky/server/config.py +81 -23
  324. sky/server/constants.py +44 -6
  325. sky/server/daemons.py +229 -0
  326. sky/server/html/token_page.html +185 -0
  327. sky/server/metrics.py +160 -0
  328. sky/server/requests/executor.py +528 -138
  329. sky/server/requests/payloads.py +351 -17
  330. sky/server/requests/preconditions.py +21 -17
  331. sky/server/requests/process.py +112 -29
  332. sky/server/requests/request_names.py +120 -0
  333. sky/server/requests/requests.py +817 -224
  334. sky/server/requests/serializers/decoders.py +82 -31
  335. sky/server/requests/serializers/encoders.py +140 -22
  336. sky/server/requests/threads.py +106 -0
  337. sky/server/rest.py +417 -0
  338. sky/server/server.py +1290 -284
  339. sky/server/state.py +20 -0
  340. sky/server/stream_utils.py +345 -57
  341. sky/server/uvicorn.py +217 -3
  342. sky/server/versions.py +270 -0
  343. sky/setup_files/MANIFEST.in +5 -0
  344. sky/setup_files/alembic.ini +156 -0
  345. sky/setup_files/dependencies.py +136 -31
  346. sky/setup_files/setup.py +44 -42
  347. sky/sky_logging.py +102 -5
  348. sky/skylet/attempt_skylet.py +1 -0
  349. sky/skylet/autostop_lib.py +129 -8
  350. sky/skylet/configs.py +27 -20
  351. sky/skylet/constants.py +171 -19
  352. sky/skylet/events.py +105 -21
  353. sky/skylet/job_lib.py +335 -104
  354. sky/skylet/log_lib.py +297 -18
  355. sky/skylet/log_lib.pyi +44 -1
  356. sky/skylet/ray_patches/__init__.py +17 -3
  357. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  358. sky/skylet/ray_patches/cli.py.diff +19 -0
  359. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  360. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  361. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  362. sky/skylet/ray_patches/updater.py.diff +18 -0
  363. sky/skylet/ray_patches/worker.py.diff +41 -0
  364. sky/skylet/services.py +564 -0
  365. sky/skylet/skylet.py +63 -4
  366. sky/skylet/subprocess_daemon.py +103 -29
  367. sky/skypilot_config.py +506 -99
  368. sky/ssh_node_pools/__init__.py +1 -0
  369. sky/ssh_node_pools/core.py +135 -0
  370. sky/ssh_node_pools/server.py +233 -0
  371. sky/task.py +621 -137
  372. sky/templates/aws-ray.yml.j2 +10 -3
  373. sky/templates/azure-ray.yml.j2 +1 -1
  374. sky/templates/do-ray.yml.j2 +1 -1
  375. sky/templates/gcp-ray.yml.j2 +57 -0
  376. sky/templates/hyperbolic-ray.yml.j2 +67 -0
  377. sky/templates/jobs-controller.yaml.j2 +27 -24
  378. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  379. sky/templates/kubernetes-ray.yml.j2 +607 -51
  380. sky/templates/lambda-ray.yml.j2 +1 -1
  381. sky/templates/nebius-ray.yml.j2 +33 -12
  382. sky/templates/paperspace-ray.yml.j2 +1 -1
  383. sky/templates/primeintellect-ray.yml.j2 +71 -0
  384. sky/templates/runpod-ray.yml.j2 +9 -1
  385. sky/templates/scp-ray.yml.j2 +3 -50
  386. sky/templates/seeweb-ray.yml.j2 +108 -0
  387. sky/templates/shadeform-ray.yml.j2 +72 -0
  388. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  389. sky/templates/websocket_proxy.py +178 -18
  390. sky/usage/usage_lib.py +18 -11
  391. sky/users/__init__.py +0 -0
  392. sky/users/model.conf +15 -0
  393. sky/users/permission.py +387 -0
  394. sky/users/rbac.py +121 -0
  395. sky/users/server.py +720 -0
  396. sky/users/token_service.py +218 -0
  397. sky/utils/accelerator_registry.py +34 -5
  398. sky/utils/admin_policy_utils.py +84 -38
  399. sky/utils/annotations.py +16 -5
  400. sky/utils/asyncio_utils.py +78 -0
  401. sky/utils/auth_utils.py +153 -0
  402. sky/utils/benchmark_utils.py +60 -0
  403. sky/utils/cli_utils/status_utils.py +159 -86
  404. sky/utils/cluster_utils.py +31 -9
  405. sky/utils/command_runner.py +354 -68
  406. sky/utils/command_runner.pyi +93 -3
  407. sky/utils/common.py +35 -8
  408. sky/utils/common_utils.py +310 -87
  409. sky/utils/config_utils.py +87 -5
  410. sky/utils/context.py +402 -0
  411. sky/utils/context_utils.py +222 -0
  412. sky/utils/controller_utils.py +264 -89
  413. sky/utils/dag_utils.py +31 -12
  414. sky/utils/db/__init__.py +0 -0
  415. sky/utils/db/db_utils.py +470 -0
  416. sky/utils/db/migration_utils.py +133 -0
  417. sky/utils/directory_utils.py +12 -0
  418. sky/utils/env_options.py +13 -0
  419. sky/utils/git.py +567 -0
  420. sky/utils/git_clone.sh +460 -0
  421. sky/utils/infra_utils.py +195 -0
  422. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  423. sky/utils/kubernetes/config_map_utils.py +133 -0
  424. sky/utils/kubernetes/create_cluster.sh +13 -27
  425. sky/utils/kubernetes/delete_cluster.sh +10 -7
  426. sky/utils/kubernetes/deploy_remote_cluster.py +1299 -0
  427. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  428. sky/utils/kubernetes/generate_kind_config.py +6 -66
  429. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  430. sky/utils/kubernetes/gpu_labeler.py +5 -5
  431. sky/utils/kubernetes/kubernetes_deploy_utils.py +354 -47
  432. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  433. sky/utils/kubernetes/ssh_utils.py +221 -0
  434. sky/utils/kubernetes_enums.py +8 -15
  435. sky/utils/lock_events.py +94 -0
  436. sky/utils/locks.py +368 -0
  437. sky/utils/log_utils.py +300 -6
  438. sky/utils/perf_utils.py +22 -0
  439. sky/utils/resource_checker.py +298 -0
  440. sky/utils/resources_utils.py +249 -32
  441. sky/utils/rich_utils.py +213 -37
  442. sky/utils/schemas.py +905 -147
  443. sky/utils/serialize_utils.py +16 -0
  444. sky/utils/status_lib.py +10 -0
  445. sky/utils/subprocess_utils.py +38 -15
  446. sky/utils/tempstore.py +70 -0
  447. sky/utils/timeline.py +24 -52
  448. sky/utils/ux_utils.py +84 -15
  449. sky/utils/validator.py +11 -1
  450. sky/utils/volume.py +86 -0
  451. sky/utils/yaml_utils.py +111 -0
  452. sky/volumes/__init__.py +13 -0
  453. sky/volumes/client/__init__.py +0 -0
  454. sky/volumes/client/sdk.py +149 -0
  455. sky/volumes/server/__init__.py +0 -0
  456. sky/volumes/server/core.py +258 -0
  457. sky/volumes/server/server.py +122 -0
  458. sky/volumes/volume.py +212 -0
  459. sky/workspaces/__init__.py +0 -0
  460. sky/workspaces/core.py +655 -0
  461. sky/workspaces/server.py +101 -0
  462. sky/workspaces/utils.py +56 -0
  463. skypilot_nightly-1.0.0.dev20251107.dist-info/METADATA +675 -0
  464. skypilot_nightly-1.0.0.dev20251107.dist-info/RECORD +594 -0
  465. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +1 -1
  466. sky/benchmark/benchmark_state.py +0 -256
  467. sky/benchmark/benchmark_utils.py +0 -641
  468. sky/clouds/service_catalog/constants.py +0 -7
  469. sky/dashboard/out/_next/static/LksQgChY5izXjokL3LcEu/_buildManifest.js +0 -1
  470. sky/dashboard/out/_next/static/chunks/236-f49500b82ad5392d.js +0 -6
  471. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  472. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  473. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  474. sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
  475. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  476. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  477. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  478. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  479. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  480. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  481. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  482. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-e15db85d0ea1fbe1.js +0 -1
  483. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  484. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  485. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  486. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-03f279c6741fb48b.js +0 -1
  487. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  488. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  489. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  490. sky/jobs/dashboard/dashboard.py +0 -223
  491. sky/jobs/dashboard/static/favicon.ico +0 -0
  492. sky/jobs/dashboard/templates/index.html +0 -831
  493. sky/jobs/server/dashboard_utils.py +0 -69
  494. sky/skylet/providers/scp/__init__.py +0 -2
  495. sky/skylet/providers/scp/config.py +0 -149
  496. sky/skylet/providers/scp/node_provider.py +0 -578
  497. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  498. sky/utils/db_utils.py +0 -100
  499. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  500. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  501. skypilot_nightly-1.0.0.dev20250509.dist-info/METADATA +0 -361
  502. skypilot_nightly-1.0.0.dev20250509.dist-info/RECORD +0 -396
  503. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  504. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  505. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  506. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  507. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  508. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  509. /sky/dashboard/out/_next/static/{LksQgChY5izXjokL3LcEu → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
  510. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
  511. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
  512. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py CHANGED
@@ -4,60 +4,87 @@ NOTE: whenever an API change is made in this file, we need to bump the
 jobs.constants.MANAGED_JOBS_VERSION and handle the API change in the
 ManagedJobCodeGen.
 """
+import asyncio
 import collections
+from datetime import datetime
 import enum
 import os
 import pathlib
+import re
 import shlex
 import textwrap
 import time
 import traceback
 import typing
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import (Any, Deque, Dict, Iterable, List, Literal, Optional, Set,
+                    TextIO, Tuple, Union)
 
 import colorama
 import filelock
-from typing_extensions import Literal
 
 from sky import backends
 from sky import exceptions
 from sky import global_user_state
 from sky import sky_logging
+from sky import skypilot_config
 from sky.adaptors import common as adaptors_common
 from sky.backends import backend_utils
+from sky.backends import cloud_vm_ray_backend
 from sky.jobs import constants as managed_job_constants
 from sky.jobs import scheduler
 from sky.jobs import state as managed_job_state
+from sky.schemas.api import responses
 from sky.skylet import constants
 from sky.skylet import job_lib
 from sky.skylet import log_lib
 from sky.usage import usage_lib
+from sky.utils import annotations
+from sky.utils import command_runner
 from sky.utils import common_utils
+from sky.utils import context_utils
+from sky.utils import controller_utils
+from sky.utils import infra_utils
 from sky.utils import log_utils
 from sky.utils import message_utils
+from sky.utils import resources_utils
 from sky.utils import rich_utils
 from sky.utils import subprocess_utils
 from sky.utils import ux_utils
 
 if typing.TYPE_CHECKING:
+    from google.protobuf import descriptor
+    from google.protobuf import json_format
+    import grpc
     import psutil
 
     import sky
     from sky import dag as dag_lib
+    from sky.schemas.generated import jobsv1_pb2
+    from sky.schemas.generated import managed_jobsv1_pb2
 else:
+    json_format = adaptors_common.LazyImport('google.protobuf.json_format')
+    descriptor = adaptors_common.LazyImport('google.protobuf.descriptor')
     psutil = adaptors_common.LazyImport('psutil')
+    grpc = adaptors_common.LazyImport('grpc')
+    jobsv1_pb2 = adaptors_common.LazyImport('sky.schemas.generated.jobsv1_pb2')
+    managed_jobsv1_pb2 = adaptors_common.LazyImport(
+        'sky.schemas.generated.managed_jobsv1_pb2')
 
 logger = sky_logging.init_logger(__name__)
 
-SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'
 # Controller checks its job's status every this many seconds.
-JOB_STATUS_CHECK_GAP_SECONDS = 20
+# This is a tradeoff between the latency and the resource usage.
+JOB_STATUS_CHECK_GAP_SECONDS = 15
 
 # Controller checks if its job has started every this many seconds.
 JOB_STARTED_STATUS_CHECK_GAP_SECONDS = 5
 
 _LOG_STREAM_CHECK_CONTROLLER_GAP_SECONDS = 5
 
+_JOB_STATUS_FETCH_MAX_RETRIES = 3
+_JOB_K8S_TRANSIENT_NW_MSG = 'Unable to connect to the server: dial tcp'
+_JOB_STATUS_FETCH_TIMEOUT_SECONDS = 30
+
 _JOB_WAITING_STATUS_MESSAGE = ux_utils.spinner_message(
     'Waiting for task to start[/]'
     '{status_str}. It may take a few minutes.\n'
@@ -72,7 +99,35 @@ _JOB_CANCELLED_MESSAGE = (
 # blocking for a long time. This should be significantly longer than the
 # JOB_STATUS_CHECK_GAP_SECONDS to avoid timing out before the controller can
 # update the state.
-_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 40
+_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 120
+
+# After enabling consolidation mode, we need to restart the API server to get
+# the jobs refresh deamon and correct number of executors. We use this file to
+# indicate that the API server has been restarted after enabling consolidation
+# mode.
+_JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE = (
+    '~/.sky/.jobs_controller_consolidation_reloaded_signal')
+
+# The response fields for managed jobs that require cluster handle
+_CLUSTER_HANDLE_FIELDS = [
+    'cluster_resources',
+    'cluster_resources_full',
+    'cloud',
+    'region',
+    'zone',
+    'infra',
+    'accelerators',
+]
+
+# The response fields for managed jobs that are not stored in the database
+# These fields will be mapped to the DB fields in the `_update_fields`.
+_NON_DB_FIELDS = _CLUSTER_HANDLE_FIELDS + ['user_yaml', 'user_name', 'details']
+
+
+class ManagedJobQueueResultType(enum.Enum):
+    """The type of the managed job queue result."""
+    DICT = 'DICT'
+    LIST = 'LIST'
 
 
 class UserSignal(enum.Enum):
@@ -83,7 +138,10 @@ class UserSignal(enum.Enum):
 
 
 # ====== internal functions ======
-def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
+def terminate_cluster(
+    cluster_name: str,
+    max_retry: int = 6,
+) -> None:
     """Terminate the cluster."""
     from sky import core  # pylint: disable=import-outside-toplevel
     retry_cnt = 0
@@ -121,42 +179,256 @@ def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
         time.sleep(backoff.current_backoff())
 
 
-def get_job_status(backend: 'backends.CloudVmRayBackend',
-                   cluster_name: str) -> Optional['job_lib.JobStatus']:
+def _validate_consolidation_mode_config(
+        current_is_consolidation_mode: bool) -> None:
+    """Validate the consolidation mode config."""
+    # Check whether the consolidation mode config is changed.
+    if current_is_consolidation_mode:
+        controller_cn = (
+            controller_utils.Controllers.JOBS_CONTROLLER.value.cluster_name)
+        if global_user_state.cluster_with_name_exists(controller_cn):
+            with ux_utils.print_exception_no_traceback():
+                raise exceptions.InconsistentConsolidationModeError(
+                    f'{colorama.Fore.RED}Consolidation mode for jobs is '
+                    f'enabled, but the controller cluster '
+                    f'{controller_cn} is still running. Please '
+                    'terminate the controller cluster first.'
+                    f'{colorama.Style.RESET_ALL}')
+    else:
+        total_jobs = managed_job_state.get_managed_jobs_total()
+        if total_jobs > 0:
+            nonterminal_jobs = (
+                managed_job_state.get_nonterminal_job_ids_by_name(
+                    None, None, all_users=True))
+            if nonterminal_jobs:
+                with ux_utils.print_exception_no_traceback():
+                    raise exceptions.InconsistentConsolidationModeError(
+                        f'{colorama.Fore.RED}Consolidation mode '
+                        'is disabled, but there are still '
+                        f'{len(nonterminal_jobs)} managed jobs '
+                        'running. Please terminate those jobs '
+                        f'first.{colorama.Style.RESET_ALL}')
+            else:
+                logger.warning(
+                    f'{colorama.Fore.YELLOW}Consolidation mode is disabled, '
+                    f'but there are {total_jobs} jobs from previous '
+                    'consolidation mode. Reset the `jobs.controller.'
+                    'consolidation_mode` to `true` and run `sky jobs queue` '
+                    'to see those jobs. Switching to normal mode will '
+                    f'lose the job history.{colorama.Style.RESET_ALL}')
+
+
+# Whether to use consolidation mode or not. When this is enabled, the managed
+# jobs controller will not be running on a separate cluster, but locally on the
+# API Server. Under the hood, we submit the job monitoring logic as processes
+# directly in the API Server.
+# Use LRU Cache so that the check is only done once.
+@annotations.lru_cache(scope='request', maxsize=2)
+def is_consolidation_mode(on_api_restart: bool = False) -> bool:
+    if os.environ.get(constants.OVERRIDE_CONSOLIDATION_MODE) is not None:
+        return True
+
+    config_consolidation_mode = skypilot_config.get_nested(
+        ('jobs', 'controller', 'consolidation_mode'), default_value=False)
+
+    signal_file = pathlib.Path(
+        _JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE).expanduser()
+
+    restart_signal_file_exists = signal_file.exists()
+    consolidation_mode = (config_consolidation_mode and
+                          restart_signal_file_exists)
+
+    if on_api_restart:
+        if config_consolidation_mode:
+            signal_file.touch()
+    else:
+        if not restart_signal_file_exists:
+            if config_consolidation_mode:
+                logger.warning(f'{colorama.Fore.YELLOW}Consolidation mode for '
+                               'managed jobs is enabled in the server config, '
+                               'but the API server has not been restarted yet. '
+                               'Please restart the API server to enable it.'
+                               f'{colorama.Style.RESET_ALL}')
+            return False
+        elif not config_consolidation_mode:
+            # Cleanup the signal file if the consolidation mode is disabled in
+            # the config. This allow the user to disable the consolidation mode
+            # without restarting the API server.
+            signal_file.unlink()
+
+    # We should only do this check on API server, as the controller will not
+    # have related config and will always seemingly disabled for consolidation
+    # mode. Check #6611 for more details.
+    if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
+        _validate_consolidation_mode_config(consolidation_mode)
+    return consolidation_mode
+
+
+def ha_recovery_for_consolidation_mode():
+    """Recovery logic for HA mode."""
+    # Touch the signal file here to avoid conflict with
+    # update_managed_jobs_statuses. Although we run this first and then start
+    # the deamon, this function is also called in cancel_jobs_by_id.
+    signal_file = pathlib.Path(
+        constants.PERSISTENT_RUN_RESTARTING_SIGNAL_FILE).expanduser()
+    signal_file.touch()
+    # No setup recovery is needed in consolidation mode, as the API server
+    # already has all runtime installed. Directly start jobs recovery here.
+    # Refers to sky/templates/kubernetes-ray.yml.j2 for more details.
+    runner = command_runner.LocalProcessCommandRunner()
+    scheduler.maybe_start_controllers()
+    with open(constants.HA_PERSISTENT_RECOVERY_LOG_PATH.format('jobs_'),
+              'w',
+              encoding='utf-8') as f:
+        start = time.time()
+        f.write(f'Starting HA recovery at {datetime.datetime.now()}\n')
+        jobs, _ = managed_job_state.get_managed_jobs_with_filters(
+            fields=['job_id', 'controller_pid', 'schedule_state', 'status'])
+        for job in jobs:
+            job_id = job['job_id']
+            controller_pid = job['controller_pid']
+
+            # In consolidation mode, it is possible that only the API server
+            # process is restarted, and the controller process is not. In such
+            # case, we don't need to do anything and the controller process will
+            # just keep running.
+            if controller_pid is not None:
+                try:
+                    if controller_process_alive(controller_pid, job_id):
+                        f.write(f'Controller pid {controller_pid} for '
+                                f'job {job_id} is still running. '
+                                'Skipping recovery.\n')
+                        continue
+                except Exception:  # pylint: disable=broad-except
+                    # _controller_process_alive may raise if psutil fails; we
+                    # should not crash the recovery logic because of this.
+                    f.write('Error checking controller pid '
+                            f'{controller_pid} for job {job_id}\n')
+
+            if job['schedule_state'] not in [
+                    managed_job_state.ManagedJobScheduleState.DONE,
+                    managed_job_state.ManagedJobScheduleState.WAITING,
+            ]:
+                script = managed_job_state.get_ha_recovery_script(job_id)
+                if script is None:
+                    f.write(f'Job {job_id}\'s recovery script does not exist. '
+                            'Skipping recovery. Job schedule state: '
+                            f'{job["schedule_state"]}\n')
+                    continue
+                runner.run(script)
+                f.write(f'Job {job_id} completed recovery at '
+                        f'{datetime.datetime.now()}\n')
+        f.write(f'HA recovery completed at {datetime.datetime.now()}\n')
+        f.write(f'Total recovery time: {time.time() - start} seconds\n')
+    signal_file.unlink()
+
+
+async def get_job_status(
+        backend: 'backends.CloudVmRayBackend', cluster_name: str,
+        job_id: Optional[int]) -> Optional['job_lib.JobStatus']:
     """Check the status of the job running on a managed job cluster.
 
     It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
     FAILED_SETUP or CANCELLED.
     """
-    handle = global_user_state.get_handle_from_cluster_name(cluster_name)
+    # TODO(luca) make this async
+    handle = await context_utils.to_thread(
+        global_user_state.get_handle_from_cluster_name, cluster_name)
     if handle is None:
         # This can happen if the cluster was preempted and background status
         # refresh already noticed and cleaned it up.
         logger.info(f'Cluster {cluster_name} not found.')
         return None
     assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
-    status = None
-    try:
-        logger.info('=== Checking the job status... ===')
-        statuses = backend.get_job_status(handle, stream_logs=False)
-        status = list(statuses.values())[0]
-        if status is None:
-            logger.info('No job found.')
-        else:
-            logger.info(f'Job status: {status}')
-    except exceptions.CommandError:
-        logger.info('Failed to connect to the cluster.')
-    logger.info('=' * 34)
-    return status
+    job_ids = None if job_id is None else [job_id]
+    for i in range(_JOB_STATUS_FETCH_MAX_RETRIES):
+        try:
+            logger.info('=== Checking the job status... ===')
+            statuses = await asyncio.wait_for(
+                context_utils.to_thread(backend.get_job_status,
+                                        handle,
+                                        job_ids=job_ids,
+                                        stream_logs=False),
+                timeout=_JOB_STATUS_FETCH_TIMEOUT_SECONDS)
+            status = list(statuses.values())[0]
+            if status is None:
+                logger.info('No job found.')
+            else:
+                logger.info(f'Job status: {status}')
+            logger.info('=' * 34)
+            return status
+        except (exceptions.CommandError, grpc.RpcError, grpc.FutureTimeoutError,
+                ValueError, TypeError, asyncio.TimeoutError) as e:
+            # Note: Each of these exceptions has some additional conditions to
+            # limit how we handle it and whether or not we catch it.
+            # Retry on k8s transient network errors. This is useful when using
+            # coreweave which may have transient network issue sometimes.
+            is_transient_error = False
+            detailed_reason = None
+            if isinstance(e, exceptions.CommandError):
+                detailed_reason = e.detailed_reason
+                if (detailed_reason is not None and
+                        _JOB_K8S_TRANSIENT_NW_MSG in detailed_reason):
+                    is_transient_error = True
+            elif isinstance(e, grpc.RpcError):
+                detailed_reason = e.details()
+                if e.code() in [
+                        grpc.StatusCode.UNAVAILABLE,
+                        grpc.StatusCode.DEADLINE_EXCEEDED
+                ]:
+                    is_transient_error = True
+            elif isinstance(e, grpc.FutureTimeoutError):
+                detailed_reason = 'Timeout'
+            elif isinstance(e, asyncio.TimeoutError):
+                detailed_reason = ('Job status check timed out after '
+                                   f'{_JOB_STATUS_FETCH_TIMEOUT_SECONDS}s')
+            # TODO(cooperc): Gracefully handle these exceptions in the backend.
+            elif isinstance(e, ValueError):
+                # If the cluster yaml is deleted in the middle of getting the
+                # SSH credentials, we could see this. See
+                # sky/global_user_state.py get_cluster_yaml_dict.
+                if re.search(r'Cluster yaml .* not found', str(e)):
+                    detailed_reason = 'Cluster yaml was deleted'
+                else:
+                    raise
+            elif isinstance(e, TypeError):
+                # We will grab the SSH credentials from the cluster yaml, but if
+                # handle.cluster_yaml is None, we will just return an empty dict
+                # for the credentials. See
+                # backend_utils.ssh_credential_from_yaml. Then, the credentials
+                # are passed as kwargs to SSHCommandRunner.__init__ - see
+                # cloud_vm_ray_backend.get_command_runners. So we can hit this
+                # TypeError if the cluster yaml is removed from the handle right
+                # when we pull it before the cluster is fully deleted.
+                error_msg_to_check = (
+                    'SSHCommandRunner.__init__() missing 2 required positional '
+                    'arguments: \'ssh_user\' and \'ssh_private_key\'')
+                if str(e) == error_msg_to_check:
+                    detailed_reason = 'SSH credentials were already cleaned up'
+                else:
+                    raise
+            if is_transient_error:
+                logger.info('Failed to connect to the cluster. Retrying '
+                            f'({i + 1}/{_JOB_STATUS_FETCH_MAX_RETRIES})...')
+                logger.info('=' * 34)
+                await asyncio.sleep(1)
+            else:
+                logger.info(f'Failed to get job status: {detailed_reason}')
+                logger.info('=' * 34)
+                return None
+    return None
 
 
-def _controller_process_alive(pid: int, job_id: int) -> bool:
+def controller_process_alive(pid: int, job_id: int) -> bool:
     """Check if the controller process is alive."""
     try:
+        if pid < 0:
+            # new job controller process will always be negative
+            pid = -pid
         process = psutil.Process(pid)
-        # The last two args of the command line should be --job-id <id>
-        job_args = process.cmdline()[-2:]
-        return process.is_running() and job_args == ['--job-id', str(job_id)]
+        cmd_str = ' '.join(process.cmdline())
+        return process.is_running() and ((f'--job-id {job_id}' in cmd_str) or
+                                         ('controller' in cmd_str))
     except psutil.NoSuchProcess:
         return False
 
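Annotation (not part of the package diff): the new `get_job_status` above retries only when a failure is classified as transient (the Kubernetes "dial tcp" message, gRPC UNAVAILABLE/DEADLINE_EXCEEDED, or a timeout) and gives up otherwise. A minimal, self-contained sketch of that retry-with-classification pattern, using hypothetical names (`retry_if_transient`, `flaky`) rather than SkyPilot's internals:

import asyncio
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar('T')


async def retry_if_transient(fetch: Callable[[], Awaitable[T]],
                             is_transient: Callable[[Exception], bool],
                             max_retries: int = 3,
                             delay_seconds: float = 1.0) -> Optional[T]:
    """Retry an awaitable factory while failures look transient."""
    for _ in range(max_retries):
        try:
            return await fetch()
        except Exception as e:  # pylint: disable=broad-except
            if not is_transient(e):
                # Permanent failure: give up immediately, mirroring the
                # non-transient `return None` branch in the hunk above.
                return None
            await asyncio.sleep(delay_seconds)
    return None


async def _demo() -> None:
    attempts = []

    async def flaky() -> str:
        attempts.append(1)
        if len(attempts) < 3:
            raise ConnectionError('Unable to connect to the server: dial tcp')
        return 'RUNNING'

    result = await retry_if_transient(
        flaky, lambda e: isinstance(e, ConnectionError))
    print(result)  # Prints 'RUNNING' after two transient failures.


if __name__ == '__main__':
    asyncio.run(_demo())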
@@ -173,6 +445,17 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
     Note: we expect that job_id, if provided, refers to a nonterminal job or a
     job that has not completed its cleanup (schedule state not DONE).
     """
+    # This signal file suggests that the controller is recovering from a
+    # failure. See sky/templates/kubernetes-ray.yml.j2 for more details.
+    # When restarting the controller processes, we don't want this event to
+    # set the job status to FAILED_CONTROLLER.
+    # TODO(tian): Change this to restart the controller process. For now we
+    # disabled it when recovering because we want to avoid caveats of infinite
+    # restart of last controller process that fully occupied the controller VM.
+    if os.path.exists(
+            os.path.expanduser(
+                constants.PERSISTENT_RUN_RESTARTING_SIGNAL_FILE)):
+        return
 
     def _cleanup_job_clusters(job_id: int) -> Optional[str]:
         """Clean up clusters for a job. Returns error message if any.
@@ -180,16 +463,24 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
         This function should not throw any exception. If it fails, it will
         capture the error message, and log/return it.
         """
+        managed_job_state.remove_ha_recovery_script(job_id)
         error_msg = None
-        tasks = managed_job_state.get_managed_jobs(job_id)
+        tasks = managed_job_state.get_managed_job_tasks(job_id)
         for task in tasks:
-            task_name = task['job_name']
-            cluster_name = generate_managed_job_cluster_name(task_name, job_id)
+            pool = task.get('pool', None)
+            if pool is None:
+                task_name = task['job_name']
+                cluster_name = generate_managed_job_cluster_name(
+                    task_name, job_id)
+            else:
+                cluster_name, _ = (
+                    managed_job_state.get_pool_submit_info(job_id))
             handle = global_user_state.get_handle_from_cluster_name(
                 cluster_name)
             if handle is not None:
                 try:
-                    terminate_cluster(cluster_name)
+                    if pool is None:
+                        terminate_cluster(cluster_name)
                 except Exception as e:  # pylint: disable=broad-except
                     error_msg = (
                         f'Failed to terminate cluster {cluster_name}: '
@@ -242,7 +533,8 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
         return
 
     for job_id in job_ids:
-        tasks = managed_job_state.get_managed_jobs(job_id)
+        assert job_id is not None
+        tasks = managed_job_state.get_managed_job_tasks(job_id)
         # Note: controller_pid and schedule_state are in the job_info table
         # which is joined to the spot table, so all tasks with the same job_id
         # will have the same value for these columns. This is what lets us just
@@ -262,9 +554,9 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
         if schedule_state == managed_job_state.ManagedJobScheduleState.DONE:
             # There are two cases where we could get a job that is DONE.
             # 1. At query time (get_jobs_to_check_status), the job was not yet
-            # DONE, but since then (before get_managed_jobs is called) it has
-            # hit a terminal status, marked itself done, and exited. This is
-            # fine.
+            # DONE, but since then (before get_managed_job_tasks is called)
+            # it has hit a terminal status, marked itself done, and exited.
+            # This is fine.
             # 2. The job is DONE, but in a non-terminal status. This is
             # unexpected. For instance, the task status is RUNNING, but the
             # job schedule_state is DONE.
@@ -311,7 +603,7 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
             failure_reason = f'No controller pid set for {schedule_state.value}'
         else:
             logger.debug(f'Checking controller pid {pid}')
-            if _controller_process_alive(pid, job_id):
+            if controller_process_alive(pid, job_id):
                 # The controller is still running, so this job is fine.
                 continue
 
@@ -369,11 +661,34 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
 
 
 def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
-                      get_end_time: bool) -> float:
+                      job_id: Optional[int], get_end_time: bool) -> float:
     """Get the submitted/ended time of the job."""
-    code = job_lib.JobLibCodeGen.get_job_submitted_or_ended_timestamp_payload(
-        job_id=None, get_ended_time=get_end_time)
     handle = global_user_state.get_handle_from_cluster_name(cluster_name)
+    assert handle is not None, (
+        f'handle for cluster {cluster_name!r} should not be None')
+    if handle.is_grpc_enabled_with_flag:
+        try:
+            if get_end_time:
+                end_ts_request = jobsv1_pb2.GetJobEndedTimestampRequest(
+                    job_id=job_id)
+                end_ts_response = backend_utils.invoke_skylet_with_retries(
+                    lambda: cloud_vm_ray_backend.SkyletClient(
+                        handle.get_grpc_channel()).get_job_ended_timestamp(
+                            end_ts_request))
+                return end_ts_response.timestamp
+            else:
+                submit_ts_request = jobsv1_pb2.GetJobSubmittedTimestampRequest(
+                    job_id=job_id)
+                submit_ts_response = backend_utils.invoke_skylet_with_retries(
+                    lambda: cloud_vm_ray_backend.SkyletClient(
+                        handle.get_grpc_channel()).get_job_submitted_timestamp(
+                            submit_ts_request))
+                return submit_ts_response.timestamp
+        except exceptions.SkyletMethodNotImplementedError:
+            pass
+
+    code = (job_lib.JobLibCodeGen.get_job_submitted_or_ended_timestamp_payload(
+        job_id=job_id, get_ended_time=get_end_time))
     returncode, stdout, stderr = backend.run_on_head(handle,
                                                      code,
                                                      stream_logs=False,
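Annotation (not part of the package diff): the hunk above prefers a gRPC Skylet call and falls back to the legacy code-generation path when the remote side does not implement the new method. A generic sketch of that fallback shape, under stated assumptions and with hypothetical names (`MethodNotImplemented`, `call_with_fallback`), not SkyPilot's actual API:

from typing import Callable, TypeVar

T = TypeVar('T')


class MethodNotImplemented(Exception):
    """Stand-in for the 'remote agent is too old' error in the hunk above."""


def call_with_fallback(fast_path: Callable[[], T],
                       legacy_path: Callable[[], T]) -> T:
    """Prefer the new RPC; fall back to the legacy mechanism if unsupported."""
    try:
        return fast_path()
    except MethodNotImplemented:
        # Older remote agents do not implement the new RPC, so degrade to the
        # previous implementation instead of failing the whole request.
        return legacy_path()


def _grpc_timestamp() -> float:
    raise MethodNotImplemented()


def _codegen_timestamp() -> float:
    return 1700000000.0


if __name__ == '__main__':
    print(call_with_fallback(_grpc_timestamp, _codegen_timestamp))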
@@ -386,16 +701,24 @@ def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
 
 
 def try_to_get_job_end_time(backend: 'backends.CloudVmRayBackend',
-                            cluster_name: str) -> float:
+                            cluster_name: str, job_id: Optional[int]) -> float:
     """Try to get the end time of the job.
 
     If the job is preempted or we can't connect to the instance for whatever
     reason, fall back to the current time.
     """
     try:
-        return get_job_timestamp(backend, cluster_name, get_end_time=True)
-    except exceptions.CommandError as e:
-        if e.returncode == 255:
+        return get_job_timestamp(backend,
+                                 cluster_name,
+                                 job_id=job_id,
+                                 get_end_time=True)
+    except (exceptions.CommandError, grpc.RpcError,
+            grpc.FutureTimeoutError) as e:
+        if isinstance(e, exceptions.CommandError) and e.returncode == 255 or \
+            (isinstance(e, grpc.RpcError) and e.code() in [
+                grpc.StatusCode.UNAVAILABLE,
+                grpc.StatusCode.DEADLINE_EXCEEDED,
+            ]) or isinstance(e, grpc.FutureTimeoutError):
             # Failed to connect - probably the instance was preempted since the
             # job completed. We shouldn't crash here, so just log and use the
             # current time.
@@ -407,7 +730,9 @@ def try_to_get_job_end_time(backend: 'backends.CloudVmRayBackend',
             raise
 
 
-def event_callback_func(job_id: int, task_id: int, task: 'sky.Task'):
+def event_callback_func(
+        job_id: int, task_id: Optional[int],
+        task: Optional['sky.Task']) -> managed_job_state.AsyncCallbackType:
     """Run event callback for the task."""
 
     def callback_func(status: str):
@@ -415,8 +740,12 @@ def event_callback_func(job_id: int, task_id: int, task: 'sky.Task'):
         if event_callback is None or task is None:
             return
         event_callback = event_callback.strip()
-        cluster_name = generate_managed_job_cluster_name(
-            task.name, job_id) if task.name else None
+        pool = managed_job_state.get_pool_from_job_id(job_id)
+        if pool is not None:
+            cluster_name, _ = (managed_job_state.get_pool_submit_info(job_id))
+        else:
+            cluster_name = generate_managed_job_cluster_name(
+                task.name, job_id) if task.name else None
         logger.info(f'=== START: event callback for {status!r} ===')
         log_path = os.path.join(constants.SKY_LOGS_DIRECTORY,
                                 'managed_job_event',
@@ -442,7 +771,10 @@ def event_callback_func(job_id: int, task_id: int, task: 'sky.Task'):
             f'Bash:{event_callback},log_path:{log_path},result:{result}')
         logger.info(f'=== END: event callback for {status!r} ===')
 
-    return callback_func
+    async def async_callback_func(status: str):
+        return await context_utils.to_thread(callback_func, status)
+
+    return async_callback_func
 
 
 # ======== user functions ========
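Annotation (not part of the package diff): `event_callback_func` now returns an async wrapper that runs the synchronous callback in a worker thread via SkyPilot's `context_utils.to_thread`. The same shape can be reproduced with the standard library's `asyncio.to_thread`; a sketch with a hypothetical `callback` for illustration:

import asyncio
import time
from typing import Awaitable, Callable


def make_async(
        callback: Callable[[str], None]) -> Callable[[str], Awaitable[None]]:
    """Wrap a blocking callback so it can be awaited from an event loop."""

    async def async_callback(status: str) -> None:
        # Run the blocking callback in a worker thread so the event loop
        # stays responsive while it executes.
        await asyncio.to_thread(callback, status)

    return async_callback


async def _demo() -> None:
    notify = make_async(lambda status: time.sleep(0.1))
    await notify('RUNNING')


if __name__ == '__main__':
    asyncio.run(_demo())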
@@ -461,20 +793,24 @@ def generate_managed_job_cluster_name(task_name: str, job_id: int) -> str:
 
 
 def cancel_jobs_by_id(job_ids: Optional[List[int]],
-                      all_users: bool = False) -> str:
+                      all_users: bool = False,
+                      current_workspace: Optional[str] = None,
+                      user_hash: Optional[str] = None) -> str:
     """Cancel jobs by id.
 
     If job_ids is None, cancel all jobs.
     """
     if job_ids is None:
         job_ids = managed_job_state.get_nonterminal_job_ids_by_name(
-            None, all_users)
+            None, user_hash, all_users)
     job_ids = list(set(job_ids))
     if not job_ids:
         return 'No job to cancel.'
-    job_id_str = ', '.join(map(str, job_ids))
-    logger.info(f'Cancelling jobs {job_id_str}.')
+    if current_workspace is None:
+        current_workspace = constants.SKYPILOT_DEFAULT_WORKSPACE
+
     cancelled_job_ids: List[int] = []
+    wrong_workspace_job_ids: List[int] = []
     for job_id in job_ids:
         # Check the status of the managed job status. If it is in
         # terminal state, we can safely skip it.
@@ -486,11 +822,41 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]],
             logger.info(f'Job {job_id} is already in terminal state '
                         f'{job_status.value}. Skipped.')
             continue
+        elif job_status == managed_job_state.ManagedJobStatus.PENDING:
+            # the if is a short circuit, this will be atomic.
+            cancelled = managed_job_state.set_pending_cancelled(job_id)
+            if cancelled:
+                cancelled_job_ids.append(job_id)
+                continue
 
         update_managed_jobs_statuses(job_id)
 
+        job_controller_pid = managed_job_state.get_job_controller_pid(job_id)
+        if job_controller_pid is not None and job_controller_pid < 0:
+            # This is a consolidated job controller, so we need to cancel the
+            # with the controller server API
+            try:
+                # we create a file as a signal to the controller server
+                signal_file = pathlib.Path(
+                    managed_job_constants.CONSOLIDATED_SIGNAL_PATH, f'{job_id}')
+                signal_file.touch()
+                cancelled_job_ids.append(job_id)
+            except OSError as e:
+                logger.error(f'Failed to cancel job {job_id} '
+                             f'with controller server: {e}')
+                # don't add it to the to be cancelled job ids, since we don't
+                # know for sure yet.
+                continue
+            continue
+
+        job_workspace = managed_job_state.get_workspace(job_id)
+        if current_workspace is not None and job_workspace != current_workspace:
+            wrong_workspace_job_ids.append(job_id)
+            continue
+
         # Send the signal to the jobs controller.
-        signal_file = pathlib.Path(SIGNAL_FILE_PREFIX.format(job_id))
+        signal_file = (pathlib.Path(
+            managed_job_constants.SIGNAL_FILE_PREFIX.format(job_id)))
         # Filelock is needed to prevent race condition between signal
         # check/removal and signal writing.
         with filelock.FileLock(str(signal_file) + '.lock'):
@@ -499,17 +865,30 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]],
499
865
  f.flush()
500
866
  cancelled_job_ids.append(job_id)
501
867
 
868
+ wrong_workspace_job_str = ''
869
+ if wrong_workspace_job_ids:
870
+ plural = 's' if len(wrong_workspace_job_ids) > 1 else ''
871
+ plural_verb = 'are' if len(wrong_workspace_job_ids) > 1 else 'is'
872
+ wrong_workspace_job_str = (
873
+ f' Job{plural} with ID{plural}'
874
+ f' {", ".join(map(str, wrong_workspace_job_ids))} '
875
+ f'{plural_verb} skipped as they are not in the active workspace '
876
+ f'{current_workspace!r}. Check the workspace of the job with: '
877
+ f'sky jobs queue')
878
+
502
879
  if not cancelled_job_ids:
503
- return 'No job to cancel.'
880
+ return f'No job to cancel.{wrong_workspace_job_str}'
504
881
  identity_str = f'Job with ID {cancelled_job_ids[0]} is'
505
882
  if len(cancelled_job_ids) > 1:
506
883
  cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids))
507
884
  identity_str = f'Jobs with IDs {cancelled_job_ids_str} are'
508
885
 
509
- return f'{identity_str} scheduled to be cancelled.'
886
+ msg = f'{identity_str} scheduled to be cancelled.{wrong_workspace_job_str}'
887
+ return msg
510
888
 
511
889
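
cancel_jobs_by_id signals cancellation through the filesystem: PENDING jobs are cancelled atomically in the state store, consolidated controllers get a per-job file touched under CONSOLIDATED_SIGNAL_PATH, and regular controllers get a signal file written under a file lock. A simplified, self-contained sketch of that signal-file handshake (the directory layout and the 'CANCEL' payload here are illustrative, not SkyPilot's actual format):

    import pathlib
    import tempfile

    import filelock  # third-party; also used by the real code


    def send_cancel_signal(signal_dir: pathlib.Path, job_id: int) -> None:
        signal_file = signal_dir / f'managed_job_{job_id}.signal'
        # The lock prevents a race between the controller checking/removing
        # the signal and the client writing it.
        with filelock.FileLock(str(signal_file) + '.lock'):
            signal_file.write_text('CANCEL\n')


    def controller_poll(signal_dir: pathlib.Path, job_id: int) -> bool:
        signal_file = signal_dir / f'managed_job_{job_id}.signal'
        with filelock.FileLock(str(signal_file) + '.lock'):
            if signal_file.exists():
                signal_file.unlink()  # consume the signal
                return True
        return False


    if __name__ == '__main__':
        tmp = pathlib.Path(tempfile.mkdtemp())
        send_cancel_signal(tmp, 42)
        print(controller_poll(tmp, 42))  # True
        print(controller_poll(tmp, 42))  # False
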
 
512
- def cancel_job_by_name(job_name: str) -> str:
890
+ def cancel_job_by_name(job_name: str,
891
+ current_workspace: Optional[str] = None) -> str:
513
892
  """Cancel a job by name."""
514
893
  job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name)
515
894
  if not job_ids:
@@ -518,11 +897,30 @@ def cancel_job_by_name(job_name: str) -> str:
518
897
  return (f'{colorama.Fore.RED}Multiple running jobs found '
519
898
  f'with name {job_name!r}.\n'
520
899
  f'Job IDs: {job_ids}{colorama.Style.RESET_ALL}')
521
- cancel_jobs_by_id(job_ids)
522
- return f'Job {job_name!r} is scheduled to be cancelled.'
900
+ msg = cancel_jobs_by_id(job_ids, current_workspace=current_workspace)
901
+ return f'{job_name!r} {msg}'
902
+
523
903
 
904
+ def cancel_jobs_by_pool(pool_name: str,
905
+ current_workspace: Optional[str] = None) -> str:
906
+ """Cancel all jobs in a pool."""
907
+ job_ids = managed_job_state.get_nonterminal_job_ids_by_pool(pool_name)
908
+ if not job_ids:
909
+ return f'No running job found in pool {pool_name!r}.'
910
+ return cancel_jobs_by_id(job_ids, current_workspace=current_workspace)
911
+
912
+
913
+ def controller_log_file_for_job(job_id: int,
914
+ create_if_not_exists: bool = False) -> str:
915
+ log_dir = os.path.expanduser(managed_job_constants.JOBS_CONTROLLER_LOGS_DIR)
916
+ if create_if_not_exists:
917
+ os.makedirs(log_dir, exist_ok=True)
918
+ return os.path.join(log_dir, f'{job_id}.log')
524
919
 
525
- def stream_logs_by_id(job_id: int, follow: bool = True) -> Tuple[str, int]:
920
+
921
+ def stream_logs_by_id(job_id: int,
922
+ follow: bool = True,
923
+ tail: Optional[int] = None) -> Tuple[str, int]:
526
924
  """Stream logs by job id.
527
925
 
528
926
  Returns:
@@ -552,18 +950,60 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> Tuple[str, int]:
552
950
  if managed_job_status.is_failed():
553
951
  job_msg = ('\nFailure reason: '
554
952
  f'{managed_job_state.get_failure_reason(job_id)}')
555
- log_file = managed_job_state.get_local_log_file(job_id, None)
556
- if log_file is not None:
557
- with open(os.path.expanduser(log_file), 'r',
558
- encoding='utf-8') as f:
559
- # Stream the logs to the console without reading the whole
560
- # file into memory.
561
- start_streaming = False
562
- for line in f:
563
- if log_lib.LOG_FILE_START_STREAMING_AT in line:
953
+ log_file_ever_existed = False
954
+ task_info = managed_job_state.get_all_task_ids_names_statuses_logs(
955
+ job_id)
956
+ num_tasks = len(task_info)
957
+ for (task_id, task_name, task_status, log_file,
958
+ logs_cleaned_at) in task_info:
959
+ if log_file:
960
+ log_file_ever_existed = True
961
+ if logs_cleaned_at is not None:
962
+ ts_str = datetime.fromtimestamp(
963
+ logs_cleaned_at).strftime('%Y-%m-%d %H:%M:%S')
964
+ print(f'Task {task_name}({task_id}) log has been '
965
+ f'cleaned at {ts_str}.')
966
+ continue
967
+ task_str = (f'Task {task_name}({task_id})'
968
+ if task_name else f'Task {task_id}')
969
+ if num_tasks > 1:
970
+ print(f'=== {task_str} ===')
971
+ with open(os.path.expanduser(log_file),
972
+ 'r',
973
+ encoding='utf-8') as f:
974
+ # Stream the logs to the console without reading the
975
+ # whole file into memory.
976
+ start_streaming = False
977
+ read_from: Union[TextIO, Deque[str]] = f
978
+ if tail is not None:
979
+ assert tail > 0
980
+ # Read only the last 'tail' lines using deque
981
+ read_from = collections.deque(f, maxlen=tail)
982
+ # We set start_streaming to True here in case
983
+ # truncating the log file removes the line that
984
+ # contains LOG_FILE_START_STREAMING_AT. This does
985
+ # not cause issues for log files shorter than tail
986
+ # because tail_logs in sky/skylet/log_lib.py also
987
+ # handles LOG_FILE_START_STREAMING_AT.
564
988
  start_streaming = True
565
- if start_streaming:
566
- print(line, end='', flush=True)
989
+ for line in read_from:
990
+ if log_lib.LOG_FILE_START_STREAMING_AT in line:
991
+ start_streaming = True
992
+ if start_streaming:
993
+ print(line, end='', flush=True)
994
+ if num_tasks > 1:
995
+ # Add the "Task finished" message for terminal states
996
+ if task_status.is_terminal():
997
+ print(ux_utils.finishing_message(
998
+ f'{task_str} finished '
999
+ f'(status: {task_status.value}).'),
1000
+ flush=True)
1001
+ if log_file_ever_existed:
1002
+ # Add the "Job finished" message for terminal states
1003
+ if managed_job_status.is_terminal():
1004
+ print(ux_utils.finishing_message(
1005
+ f'Job finished (status: {managed_job_status.value}).'),
1006
+ flush=True)
567
1007
  return '', exceptions.JobExitCode.from_managed_job_status(
568
1008
  managed_job_status)
569
1009
  return (f'{colorama.Fore.YELLOW}'
@@ -585,12 +1025,19 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> Tuple[str, int]:
585
1025
 
586
1026
  while should_keep_logging(managed_job_status):
587
1027
  handle = None
1028
+ job_id_to_tail = None
588
1029
  if task_id is not None:
589
- task_name = managed_job_state.get_task_name(job_id, task_id)
590
- cluster_name = generate_managed_job_cluster_name(
591
- task_name, job_id)
592
- handle = global_user_state.get_handle_from_cluster_name(
593
- cluster_name)
1030
+ pool = managed_job_state.get_pool_from_job_id(job_id)
1031
+ if pool is not None:
1032
+ cluster_name, job_id_to_tail = (
1033
+ managed_job_state.get_pool_submit_info(job_id))
1034
+ else:
1035
+ task_name = managed_job_state.get_task_name(job_id, task_id)
1036
+ cluster_name = generate_managed_job_cluster_name(
1037
+ task_name, job_id)
1038
+ if cluster_name is not None:
1039
+ handle = global_user_state.get_handle_from_cluster_name(
1040
+ cluster_name)
594
1041
 
595
1042
  # Check the handle: The cluster can be preempted and removed from
596
1043
  # the table before the managed job state is updated by the
@@ -620,10 +1067,12 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> Tuple[str, int]:
620
1067
  managed_job_state.ManagedJobStatus.RUNNING)
621
1068
  assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
622
1069
  status_display.stop()
1070
+ tail_param = tail if tail is not None else 0
623
1071
  returncode = backend.tail_logs(handle,
624
- job_id=None,
1072
+ job_id=job_id_to_tail,
625
1073
  managed_job_id=job_id,
626
- follow=follow)
1074
+ follow=follow,
1075
+ tail=tail_param)
627
1076
  if returncode in [rc.value for rc in exceptions.JobExitCode]:
628
1077
  # If the log tailing exits with a known exit code we can safely
629
1078
  # break the loop because it indicates the tailing process
@@ -760,7 +1209,8 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> Tuple[str, int]:
760
1209
  def stream_logs(job_id: Optional[int],
761
1210
  job_name: Optional[str],
762
1211
  controller: bool = False,
763
- follow: bool = True) -> Tuple[str, int]:
1212
+ follow: bool = True,
1213
+ tail: Optional[int] = None) -> Tuple[str, int]:
764
1214
  """Stream logs by job id or job name.
765
1215
 
766
1216
  Returns:
@@ -776,7 +1226,8 @@ def stream_logs(job_id: Optional[int],
776
1226
  if controller:
777
1227
  if job_id is None:
778
1228
  assert job_name is not None
779
- managed_jobs = managed_job_state.get_managed_jobs()
1229
+ managed_jobs, _ = managed_job_state.get_managed_jobs_with_filters(
1230
+ name_match=job_name, fields=['job_id', 'job_name', 'status'])
780
1231
  # We manually filter the jobs by name, instead of using
781
1232
  # get_nonterminal_job_ids_by_name, as with `controller=True`, we
782
1233
  # should be able to show the logs for jobs in terminal states.
@@ -799,9 +1250,7 @@ def stream_logs(job_id: Optional[int],
799
1250
  job_id = managed_job_ids.pop()
800
1251
  assert job_id is not None, (job_id, job_name)
801
1252
 
802
- controller_log_path = os.path.join(
803
- os.path.expanduser(managed_job_constants.JOBS_CONTROLLER_LOGS_DIR),
804
- f'{job_id}.log')
1253
+ controller_log_path = controller_log_file_for_job(job_id)
805
1254
  job_status = None
806
1255
 
807
1256
  # Wait for the log file to be written
@@ -831,7 +1280,12 @@ def stream_logs(job_id: Optional[int],
831
1280
  with open(controller_log_path, 'r', newline='', encoding='utf-8') as f:
832
1281
  # Note: we do not need to care about start_stream_at here, since
833
1282
  # that should be in the job log printed above.
834
- for line in f:
1283
+ read_from: Union[TextIO, Deque[str]] = f
1284
+ if tail is not None:
1285
+ assert tail > 0
1286
+ # Read only the last 'tail' lines efficiently using deque
1287
+ read_from = collections.deque(f, maxlen=tail)
1288
+ for line in read_from:
835
1289
  print(line, end='')
836
1290
  # Flush.
837
1291
  print(end='', flush=True)
@@ -883,61 +1337,384 @@ def stream_logs(job_id: Optional[int],
883
1337
  f'Multiple running jobs found with name {job_name!r}.')
884
1338
  job_id = job_ids[0]
885
1339
 
886
- return stream_logs_by_id(job_id, follow)
1340
+ return stream_logs_by_id(job_id, follow, tail)
1341
+
1342
+
1343
+ def dump_managed_job_queue(
1344
+ skip_finished: bool = False,
1345
+ accessible_workspaces: Optional[List[str]] = None,
1346
+ job_ids: Optional[List[int]] = None,
1347
+ workspace_match: Optional[str] = None,
1348
+ name_match: Optional[str] = None,
1349
+ pool_match: Optional[str] = None,
1350
+ page: Optional[int] = None,
1351
+ limit: Optional[int] = None,
1352
+ user_hashes: Optional[List[Optional[str]]] = None,
1353
+ statuses: Optional[List[str]] = None,
1354
+ fields: Optional[List[str]] = None,
1355
+ ) -> str:
1356
+ return message_utils.encode_payload(
1357
+ get_managed_job_queue(skip_finished, accessible_workspaces, job_ids,
1358
+ workspace_match, name_match, pool_match, page,
1359
+ limit, user_hashes, statuses, fields))
887
1360
 
888
1361
 
889
- def dump_managed_job_queue() -> str:
890
- jobs = managed_job_state.get_managed_jobs()
1362
+ def _update_fields(fields: List[str],) -> Tuple[List[str], bool]:
1363
+ """Update the fields list to include the necessary fields.
1364
+
1365
+ Args:
1366
+ fields: The fields to update.
1367
+
1368
+ It will:
1369
+ - Add the necessary dependent fields to the list.
1370
+ - Remove the fields that are not in the DB.
1371
+ - Determine if cluster handle is required.
1372
+
1373
+ Returns:
1374
+ A tuple containing the updated fields and a boolean indicating if
1375
+ cluster handle is required.
1376
+ """
1377
+ cluster_handle_required = True
1378
+ if _cluster_handle_not_required(fields):
1379
+ cluster_handle_required = False
1380
+ # Copy the list to avoid modifying the original list
1381
+ new_fields = fields.copy()
1382
+ # status and job_id are always included
1383
+ if 'status' not in new_fields:
1384
+ new_fields.append('status')
1385
+ if 'job_id' not in new_fields:
1386
+ new_fields.append('job_id')
1387
+ # user_hash is required if user_name is present
1388
+ if 'user_name' in new_fields and 'user_hash' not in new_fields:
1389
+ new_fields.append('user_hash')
1390
+ if 'job_duration' in new_fields:
1391
+ if 'last_recovered_at' not in new_fields:
1392
+ new_fields.append('last_recovered_at')
1393
+ if 'end_at' not in new_fields:
1394
+ new_fields.append('end_at')
1395
+ if 'job_name' in new_fields and 'task_name' not in new_fields:
1396
+ new_fields.append('task_name')
1397
+ if 'details' in new_fields:
1398
+ if 'schedule_state' not in new_fields:
1399
+ new_fields.append('schedule_state')
1400
+ if 'priority' not in new_fields:
1401
+ new_fields.append('priority')
1402
+ if 'failure_reason' not in new_fields:
1403
+ new_fields.append('failure_reason')
1404
+ if 'user_yaml' in new_fields:
1405
+ if 'original_user_yaml_path' not in new_fields:
1406
+ new_fields.append('original_user_yaml_path')
1407
+ if 'original_user_yaml_content' not in new_fields:
1408
+ new_fields.append('original_user_yaml_content')
1409
+ if cluster_handle_required:
1410
+ if 'task_name' not in new_fields:
1411
+ new_fields.append('task_name')
1412
+ if 'current_cluster_name' not in new_fields:
1413
+ new_fields.append('current_cluster_name')
1414
+ # Remove _NON_DB_FIELDS
1415
+ # These fields have been mapped to the DB fields in the above code, so we
1416
+ # don't need to include them in the updated fields.
1417
+ for field in _NON_DB_FIELDS:
1418
+ if field in new_fields:
1419
+ new_fields.remove(field)
1420
+ return new_fields, cluster_handle_required
1421
+
1422
+
1423
+ def _cluster_handle_not_required(fields: List[str]) -> bool:
1424
+ """Determine if cluster handle is not required.
1425
+
1426
+ Args:
1427
+ fields: The fields to check if they contain any of the cluster handle
1428
+ fields.
1429
+
1430
+ Returns:
1431
+ True if the fields do not contain any of the cluster handle fields,
1432
+ False otherwise.
1433
+ """
1434
+ return not any(field in fields for field in _CLUSTER_HANDLE_FIELDS)
1435
+
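
The field expansion is worth a concrete illustration. The sketch below mirrors the logic of _update_fields with hypothetical stand-ins for _CLUSTER_HANDLE_FIELDS and _NON_DB_FIELDS (their real contents are defined elsewhere in this module), showing how a request for ['user_name', 'job_duration'] grows into the DB columns that actually need to be fetched:

    from typing import List, Tuple

    # Hypothetical stand-ins; the real constants live elsewhere in the module.
    _CLUSTER_HANDLE_FIELDS = {'cluster_resources', 'region', 'zone', 'infra'}
    _NON_DB_FIELDS = {'user_name', 'job_duration', 'details', 'user_yaml'}

    _DEPENDENCIES = {
        'user_name': ['user_hash'],
        'job_duration': ['last_recovered_at', 'end_at'],
        'job_name': ['task_name'],
        'details': ['schedule_state', 'priority', 'failure_reason'],
    }


    def expand_fields(fields: List[str]) -> Tuple[List[str], bool]:
        cluster_handle_required = any(
            f in _CLUSTER_HANDLE_FIELDS for f in fields)
        expanded = list(fields) + ['status', 'job_id']
        for field in fields:
            expanded += _DEPENDENCIES.get(field, [])
        # Drop the derived (non-DB) fields; their dependencies remain.
        expanded = [f for f in dict.fromkeys(expanded)
                    if f not in _NON_DB_FIELDS]
        return expanded, cluster_handle_required


    print(expand_fields(['user_name', 'job_duration']))
    # (['status', 'job_id', 'user_hash', 'last_recovered_at', 'end_at'], False)
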
1436
+
1437
+ def get_managed_job_queue(
1438
+ skip_finished: bool = False,
1439
+ accessible_workspaces: Optional[List[str]] = None,
1440
+ job_ids: Optional[List[int]] = None,
1441
+ workspace_match: Optional[str] = None,
1442
+ name_match: Optional[str] = None,
1443
+ pool_match: Optional[str] = None,
1444
+ page: Optional[int] = None,
1445
+ limit: Optional[int] = None,
1446
+ user_hashes: Optional[List[Optional[str]]] = None,
1447
+ statuses: Optional[List[str]] = None,
1448
+ fields: Optional[List[str]] = None,
1449
+ ) -> Dict[str, Any]:
1450
+ """Get the managed job queue.
1451
+
1452
+ Args:
1453
+ skip_finished: Whether to skip finished jobs.
1454
+ accessible_workspaces: The accessible workspaces.
1455
+ job_ids: The job ids.
1456
+ workspace_match: The workspace name to match.
1457
+ name_match: The job name to match.
1458
+ pool_match: The pool name to match.
1459
+ page: The page number.
1460
+ limit: The limit number.
1461
+ user_hashes: The user hashes.
1462
+ statuses: The statuses.
1463
+ fields: The fields to include in the response.
1464
+
1465
+ Returns:
1466
+ A dictionary containing the managed job queue.
1467
+ """
1468
+ cluster_handle_required = True
1469
+ updated_fields = None
1470
+ # The caller only needs to specify the fields in the
1471
+ # `class ManagedJobRecord` in `response.py`; the `_update_fields`
1472
+ # function will add the necessary dependent fields to the list. For
1473
+ # example, if the caller specifies `['user_name']`, `_update_fields`
1474
+ # will add `['user_hash']` to the list.
1475
+ if fields:
1476
+ updated_fields, cluster_handle_required = _update_fields(fields)
1477
+
1478
+ total_no_filter = managed_job_state.get_managed_jobs_total()
1479
+
1480
+ status_counts = managed_job_state.get_status_count_with_filters(
1481
+ fields=fields,
1482
+ job_ids=job_ids,
1483
+ accessible_workspaces=accessible_workspaces,
1484
+ workspace_match=workspace_match,
1485
+ name_match=name_match,
1486
+ pool_match=pool_match,
1487
+ user_hashes=user_hashes,
1488
+ skip_finished=skip_finished,
1489
+ )
1490
+
1491
+ jobs, total = managed_job_state.get_managed_jobs_with_filters(
1492
+ fields=updated_fields,
1493
+ job_ids=job_ids,
1494
+ accessible_workspaces=accessible_workspaces,
1495
+ workspace_match=workspace_match,
1496
+ name_match=name_match,
1497
+ pool_match=pool_match,
1498
+ user_hashes=user_hashes,
1499
+ statuses=statuses,
1500
+ skip_finished=skip_finished,
1501
+ page=page,
1502
+ limit=limit,
1503
+ )
1504
+
1505
+ if cluster_handle_required:
1506
+ # Fetch the cluster name to handle map for managed clusters only.
1507
+ cluster_name_to_handle = (
1508
+ global_user_state.get_cluster_name_to_handle_map(is_managed=True))
1509
+
1510
+ highest_blocking_priority = constants.MIN_PRIORITY
1511
+ if not fields or 'details' in fields:
1512
+ # Figure out what the highest priority blocking job is. We need to know
1513
+ # this in order to determine if other jobs are blocked by a higher priority
1514
+ # job, or just by the limited controller resources.
1515
+ highest_blocking_priority = (
1516
+ managed_job_state.get_managed_jobs_highest_priority())
891
1517
 
892
1518
  for job in jobs:
893
- end_at = job['end_at']
894
- if end_at is None:
895
- end_at = time.time()
896
-
897
- job_submitted_at = job['last_recovered_at'] - job['job_duration']
898
- if job['status'] == managed_job_state.ManagedJobStatus.RECOVERING:
899
- # When job is recovering, the duration is exact job['job_duration']
900
- job_duration = job['job_duration']
901
- elif job_submitted_at > 0:
902
- job_duration = end_at - job_submitted_at
903
- else:
904
- # When job_start_at <= 0, that means the last_recovered_at is not
905
- # set yet, i.e. the job is not started.
906
- job_duration = 0
907
- job['job_duration'] = job_duration
1519
+ if not fields or 'job_duration' in fields:
1520
+ end_at = job['end_at']
1521
+ if end_at is None:
1522
+ end_at = time.time()
1523
+
1524
+ job_submitted_at = job['last_recovered_at'] - job['job_duration']
1525
+ if job['status'] == managed_job_state.ManagedJobStatus.RECOVERING:
1526
+ # When job is recovering, the duration is exact
1527
+ # job['job_duration']
1528
+ job_duration = job['job_duration']
1529
+ elif job_submitted_at > 0:
1530
+ job_duration = end_at - job_submitted_at
1531
+ else:
1532
+ # When job_start_at <= 0, that means the last_recovered_at
1533
+ # is not set yet, i.e. the job is not started.
1534
+ job_duration = 0
1535
+ job['job_duration'] = job_duration
908
1536
  job['status'] = job['status'].value
909
- job['schedule_state'] = job['schedule_state'].value
910
-
911
- cluster_name = generate_managed_job_cluster_name(
912
- job['task_name'], job['job_id'])
913
- handle = global_user_state.get_handle_from_cluster_name(cluster_name)
914
- if handle is not None:
915
- assert isinstance(handle, backends.CloudVmRayResourceHandle)
916
- job['cluster_resources'] = (
917
- f'{handle.launched_nodes}x {handle.launched_resources}')
918
- job['region'] = handle.launched_resources.region
1537
+ if not fields or 'schedule_state' in fields:
1538
+ job['schedule_state'] = job['schedule_state'].value
919
1539
  else:
920
- # FIXME(zongheng): display the last cached values for these.
921
- job['cluster_resources'] = '-'
922
- job['region'] = '-'
1540
+ job['schedule_state'] = None
1541
+
1542
+ if cluster_handle_required:
1543
+ cluster_name = job.get('current_cluster_name', None)
1544
+ if cluster_name is None:
1545
+ cluster_name = generate_managed_job_cluster_name(
1546
+ job['task_name'], job['job_id'])
1547
+ handle = cluster_name_to_handle.get(
1548
+ cluster_name, None) if cluster_name is not None else None
1549
+ if isinstance(handle, backends.CloudVmRayResourceHandle):
1550
+ resources_str_simple, resources_str_full = (
1551
+ resources_utils.get_readable_resources_repr(
1552
+ handle, simplified_only=False))
1553
+ assert resources_str_full is not None
1554
+ job['cluster_resources'] = resources_str_simple
1555
+ job['cluster_resources_full'] = resources_str_full
1556
+ job['cloud'] = str(handle.launched_resources.cloud)
1557
+ job['region'] = handle.launched_resources.region
1558
+ job['zone'] = handle.launched_resources.zone
1559
+ job['infra'] = infra_utils.InfraInfo(
1560
+ str(handle.launched_resources.cloud),
1561
+ handle.launched_resources.region,
1562
+ handle.launched_resources.zone).formatted_str()
1563
+ job['accelerators'] = handle.launched_resources.accelerators
1564
+ else:
1565
+ # FIXME(zongheng): display the last cached values for these.
1566
+ job['cluster_resources'] = '-'
1567
+ job['cluster_resources_full'] = '-'
1568
+ job['cloud'] = '-'
1569
+ job['region'] = '-'
1570
+ job['zone'] = '-'
1571
+ job['infra'] = '-'
1572
+
1573
+ if not fields or 'details' in fields:
1574
+ # Add details about schedule state / backoff.
1575
+ state_details = None
1576
+ if job['schedule_state'] == 'ALIVE_BACKOFF':
1577
+ state_details = 'In backoff, waiting for resources'
1578
+ elif job['schedule_state'] in ('WAITING', 'ALIVE_WAITING'):
1579
+ priority = job.get('priority')
1580
+ if (priority is not None and
1581
+ priority < highest_blocking_priority):
1582
+ # Job is lower priority than some other blocking job.
1583
+ state_details = 'Waiting for higher priority jobs to launch'
1584
+ else:
1585
+ state_details = 'Waiting for other jobs to launch'
1586
+
1587
+ if state_details and job['failure_reason']:
1588
+ job['details'] = f'{state_details} - {job["failure_reason"]}'
1589
+ elif state_details:
1590
+ job['details'] = state_details
1591
+ elif job['failure_reason']:
1592
+ job['details'] = f'Failure: {job["failure_reason"]}'
1593
+ else:
1594
+ job['details'] = None
1595
+
1596
+ return {
1597
+ 'jobs': jobs,
1598
+ 'total': total,
1599
+ 'total_no_filter': total_no_filter,
1600
+ 'status_counts': status_counts
1601
+ }
1602
+
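
get_managed_job_queue returns a plain dict with 'jobs', 'total', 'total_no_filter', and 'status_counts'. One derived field in the per-job loop above is worth a worked example: the submission time is reconstructed as last_recovered_at minus the stored job_duration, and the displayed duration for a finished job is end_at minus that reconstruction (while RECOVERING, the stored duration is shown as-is). With made-up numbers:

    # Made-up record values for illustration.
    last_recovered_at = 1000.0   # timestamp when the job (re)started running
    stored_duration = 40.0       # run time accumulated before that restart
    end_at = 1300.0              # timestamp when the job finished

    job_submitted_at = last_recovered_at - stored_duration   # 960.0
    if job_submitted_at > 0:
        # Includes time spent across recoveries.
        job_duration = end_at - job_submitted_at             # 340.0
    else:
        # last_recovered_at was never set, i.e. the job has not started.
        job_duration = 0.0
    print(job_duration)  # 340.0
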
1603
+
1604
+ def filter_jobs(
1605
+ jobs: List[Dict[str, Any]],
1606
+ workspace_match: Optional[str],
1607
+ name_match: Optional[str],
1608
+ pool_match: Optional[str],
1609
+ page: Optional[int],
1610
+ limit: Optional[int],
1611
+ user_match: Optional[str] = None,
1612
+ enable_user_match: bool = False,
1613
+ statuses: Optional[List[str]] = None,
1614
+ ) -> Tuple[List[Dict[str, Any]], int, Dict[str, int]]:
1615
+ """Filter jobs based on the given criteria.
1616
+
1617
+ Args:
1618
+ jobs: List of jobs to filter.
1619
+ workspace_match: Workspace name to filter.
1620
+ name_match: Job name to filter.
1621
+ pool_match: Pool name to filter.
1622
+ page: Page to filter.
1623
+ limit: Limit to filter.
1624
+ user_match: User name to filter.
1625
+ enable_user_match: Whether to enable user match.
1626
+ statuses: Statuses to filter.
1627
+
1628
+ Returns:
1629
+ List of filtered jobs
1630
+ Total number of jobs
1631
+ Dictionary of status counts
1632
+ """
923
1633
 
924
- return message_utils.encode_payload(jobs)
1634
+ # TODO(hailong): refactor the whole function including the
1635
+ # `dump_managed_job_queue()` to use DB filtering.
1636
+
1637
+ def _pattern_matches(job: Dict[str, Any], key: str,
1638
+ pattern: Optional[str]) -> bool:
1639
+ if pattern is None:
1640
+ return True
1641
+ if key not in job:
1642
+ return False
1643
+ value = job[key]
1644
+ if not value:
1645
+ return False
1646
+ return pattern in str(value)
1647
+
1648
+ def _handle_page_and_limit(
1649
+ result: List[Dict[str, Any]],
1650
+ page: Optional[int],
1651
+ limit: Optional[int],
1652
+ ) -> List[Dict[str, Any]]:
1653
+ if page is None and limit is None:
1654
+ return result
1655
+ assert page is not None and limit is not None, (page, limit)
1656
+ # page starts from 1
1657
+ start = (page - 1) * limit
1658
+ end = min(start + limit, len(result))
1659
+ return result[start:end]
925
1660
 
1661
+ status_counts: Dict[str, int] = collections.defaultdict(int)
1662
+ result = []
1663
+ checks = [
1664
+ ('workspace', workspace_match),
1665
+ ('job_name', name_match),
1666
+ ('pool', pool_match),
1667
+ ]
1668
+ if enable_user_match:
1669
+ checks.append(('user_name', user_match))
926
1670
 
927
- def load_managed_job_queue(payload: str) -> List[Dict[str, Any]]:
1671
+ for job in jobs:
1672
+ if not all(
1673
+ _pattern_matches(job, key, pattern) for key, pattern in checks):
1674
+ continue
1675
+ status_counts[job['status'].value] += 1
1676
+ if statuses:
1677
+ if job['status'].value not in statuses:
1678
+ continue
1679
+ result.append(job)
1680
+
1681
+ total = len(result)
1682
+
1683
+ return _handle_page_and_limit(result, page, limit), total, status_counts
1684
+
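
filter_jobs applies the match patterns first, counts statuses on the matched set, and only then slices the result with 1-based page/limit arithmetic. A standalone sketch of the slicing helper:

    from typing import List, Optional


    def paginate(result: List[int], page: Optional[int],
                 limit: Optional[int]) -> List[int]:
        if page is None and limit is None:
            return result
        assert page is not None and limit is not None, (page, limit)
        start = (page - 1) * limit          # pages are 1-based
        end = min(start + limit, len(result))
        return result[start:end]


    jobs = list(range(1, 11))                     # job ids 1..10
    print(paginate(jobs, page=1, limit=4))        # [1, 2, 3, 4]
    print(paginate(jobs, page=3, limit=4))        # [9, 10]
    print(paginate(jobs, page=None, limit=None))  # all ten ids
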
1685
+
1686
+ def load_managed_job_queue(
1687
+ payload: str
1688
+ ) -> Tuple[List[Dict[str, Any]], int, ManagedJobQueueResultType, int, Dict[
1689
+ str, int]]:
928
1690
  """Load job queue from json string."""
929
- jobs = message_utils.decode_payload(payload)
1691
+ result = message_utils.decode_payload(payload)
1692
+ result_type = ManagedJobQueueResultType.DICT
1693
+ status_counts: Dict[str, int] = {}
1694
+ if isinstance(result, dict):
1695
+ jobs: List[Dict[str, Any]] = result['jobs']
1696
+ total: int = result['total']
1697
+ status_counts = result.get('status_counts', {})
1698
+ total_no_filter: int = result.get('total_no_filter', total)
1699
+ else:
1700
+ jobs = result
1701
+ total = len(jobs)
1702
+ total_no_filter = total
1703
+ result_type = ManagedJobQueueResultType.LIST
1704
+
1705
+ all_users = global_user_state.get_all_users()
1706
+ all_users_map = {user.id: user.name for user in all_users}
930
1707
  for job in jobs:
931
1708
  job['status'] = managed_job_state.ManagedJobStatus(job['status'])
932
1709
  if 'user_hash' in job and job['user_hash'] is not None:
933
1710
  # Skip jobs that do not have user_hash info.
934
1711
  # TODO(cooperc): Remove check before 0.12.0.
935
- job['user_name'] = global_user_state.get_user(job['user_hash']).name
936
- return jobs
1712
+ job['user_name'] = all_users_map.get(job['user_hash'])
1713
+ return jobs, total, result_type, total_no_filter, status_counts
937
1714
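
load_managed_job_queue now has to accept two wire formats: newer controllers return a dict (jobs/total/total_no_filter/status_counts), while older ones return a bare list of job records. A small sketch of that shape dispatch; the real function additionally rehydrates status enums and resolves user names, which is omitted here:

    from typing import Any, Dict, List, Tuple, Union

    Payload = Union[List[Dict[str, Any]], Dict[str, Any]]


    def load_queue(result: Payload) -> Tuple[List[Dict[str, Any]], int, int,
                                             Dict[str, int]]:
        if isinstance(result, dict):
            # New-style payload from controllers that support filtering.
            jobs = result['jobs']
            total = result['total']
            return (jobs, total, result.get('total_no_filter', total),
                    result.get('status_counts', {}))
        # Old-style payload: a bare list of job records.
        return result, len(result), len(result), {}


    print(load_queue([{'job_id': 1}, {'job_id': 2}]))
    print(load_queue({'jobs': [{'job_id': 1}], 'total': 7,
                      'total_no_filter': 12, 'status_counts': {'RUNNING': 1}}))
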
 
938
1715
 
939
1716
  def _get_job_status_from_tasks(
940
- job_tasks: List[Dict[str, Any]]
1717
+ job_tasks: Union[List[responses.ManagedJobRecord], List[Dict[str, Any]]]
941
1718
  ) -> Tuple[managed_job_state.ManagedJobStatus, int]:
942
1719
  """Get the current task status and the current task id for a job."""
943
1720
  managed_task_status = managed_job_state.ManagedJobStatus.SUCCEEDED
@@ -949,7 +1726,7 @@ def _get_job_status_from_tasks(
949
1726
  # Use the first non-succeeded status.
950
1727
  if managed_task_status != managed_job_state.ManagedJobStatus.SUCCEEDED:
951
1728
  # TODO(zhwu): we should not blindly use the first non-
952
- # succeeded as the status could be changed to SUBMITTED
1729
+ # succeeded as the status could be changed to PENDING
953
1730
  # when going from one task to the next one, which can be
954
1731
  # confusing.
955
1732
  break
@@ -957,29 +1734,40 @@ def _get_job_status_from_tasks(
957
1734
 
958
1735
 
959
1736
  @typing.overload
960
- def format_job_table(tasks: List[Dict[str, Any]],
961
- show_all: bool,
962
- show_user: bool,
963
- return_rows: Literal[False] = False,
964
- max_jobs: Optional[int] = None) -> str:
1737
+ def format_job_table(
1738
+ tasks: List[Dict[str, Any]],
1739
+ show_all: bool,
1740
+ show_user: bool,
1741
+ return_rows: Literal[False] = False,
1742
+ pool_status: Optional[List[Dict[str, Any]]] = None,
1743
+ max_jobs: Optional[int] = None,
1744
+ job_status_counts: Optional[Dict[str, int]] = None,
1745
+ ) -> str:
965
1746
  ...
966
1747
 
967
1748
 
968
1749
  @typing.overload
969
- def format_job_table(tasks: List[Dict[str, Any]],
970
- show_all: bool,
971
- show_user: bool,
972
- return_rows: Literal[True],
973
- max_jobs: Optional[int] = None) -> List[List[str]]:
1750
+ def format_job_table(
1751
+ tasks: List[Dict[str, Any]],
1752
+ show_all: bool,
1753
+ show_user: bool,
1754
+ return_rows: Literal[True],
1755
+ pool_status: Optional[List[Dict[str, Any]]] = None,
1756
+ max_jobs: Optional[int] = None,
1757
+ job_status_counts: Optional[Dict[str, int]] = None,
1758
+ ) -> List[List[str]]:
974
1759
  ...
975
1760
 
976
1761
 
977
1762
  def format_job_table(
978
- tasks: List[Dict[str, Any]],
979
- show_all: bool,
980
- show_user: bool,
981
- return_rows: bool = False,
982
- max_jobs: Optional[int] = None) -> Union[str, List[List[str]]]:
1763
+ tasks: List[Dict[str, Any]],
1764
+ show_all: bool,
1765
+ show_user: bool,
1766
+ return_rows: bool = False,
1767
+ pool_status: Optional[List[Dict[str, Any]]] = None,
1768
+ max_jobs: Optional[int] = None,
1769
+ job_status_counts: Optional[Dict[str, int]] = None,
1770
+ ) -> Union[str, List[List[str]]]:
983
1771
  """Returns managed jobs as a formatted string.
984
1772
 
985
1773
  Args:
@@ -988,6 +1776,8 @@ def format_job_table(
988
1776
  max_jobs: The maximum number of jobs to show in the table.
989
1777
  return_rows: If True, return the rows as a list of strings instead of
990
1778
  all rows concatenated into a single string.
1779
+ pool_status: List of pool status dictionaries with replica_info.
1780
+ job_status_counts: The counts of each job status.
991
1781
 
992
1782
  Returns: A formatted string of managed jobs, if not `return_rows`; otherwise
993
1783
  a list of "rows" (each of which is a list of str).
@@ -1004,16 +1794,41 @@ def format_job_table(
1004
1794
  return (task['user'], task['job_id'])
1005
1795
  return task['job_id']
1006
1796
 
1797
+ def _get_job_id_to_worker_map(
1798
+ pool_status: Optional[List[Dict[str, Any]]]) -> Dict[int, int]:
1799
+ """Create a mapping from job_id to worker replica_id.
1800
+
1801
+ Args:
1802
+ pool_status: List of pool status dictionaries with replica_info.
1803
+
1804
+ Returns:
1805
+ Dictionary mapping job_id to replica_id (worker ID).
1806
+ """
1807
+ job_to_worker: Dict[int, int] = {}
1808
+ if pool_status is None:
1809
+ return job_to_worker
1810
+ for pool in pool_status:
1811
+ replica_info = pool.get('replica_info', [])
1812
+ for replica in replica_info:
1813
+ used_by = replica.get('used_by')
1814
+ if used_by is not None:
1815
+ job_to_worker[used_by] = replica.get('replica_id')
1816
+ return job_to_worker
1817
+
1818
+ # Create mapping from job_id to worker replica_id
1819
+ job_to_worker = _get_job_id_to_worker_map(pool_status)
1820
+
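
The worker annotation in the POOL column comes from this mapping, which inverts the pool's replica_info (replica -> used_by job) into a job_id -> replica_id lookup. For example, with a made-up pool status:

    pool_status = [{
        'replica_info': [
            {'replica_id': 1, 'used_by': 17},
            {'replica_id': 2, 'used_by': None},   # idle worker
            {'replica_id': 3, 'used_by': 21},
        ],
    }]

    job_to_worker = {}
    for pool in pool_status:
        for replica in pool.get('replica_info', []):
            used_by = replica.get('used_by')
            if used_by is not None:
                job_to_worker[used_by] = replica.get('replica_id')

    print(job_to_worker)  # {17: 1, 21: 3}
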
1007
1821
  for task in tasks:
1008
1822
  # The tasks within the same job_id are already sorted
1009
1823
  # by the task_id.
1010
1824
  jobs[get_hash(task)].append(task)
1011
1825
 
1012
- status_counts: Dict[str, int] = collections.defaultdict(int)
1826
+ workspaces = set()
1013
1827
  for job_tasks in jobs.values():
1014
- managed_job_status = _get_job_status_from_tasks(job_tasks)[0]
1015
- if not managed_job_status.is_terminal():
1016
- status_counts[managed_job_status.value] += 1
1828
+ workspaces.add(job_tasks[0].get('workspace',
1829
+ constants.SKYPILOT_DEFAULT_WORKSPACE))
1830
+
1831
+ show_workspace = len(workspaces) > 1 or show_all
1017
1832
 
1018
1833
  user_cols: List[str] = []
1019
1834
  if show_user:
@@ -1024,26 +1839,43 @@ def format_job_table(
1024
1839
  columns = [
1025
1840
  'ID',
1026
1841
  'TASK',
1842
+ *(['WORKSPACE'] if show_workspace else []),
1027
1843
  'NAME',
1028
1844
  *user_cols,
1029
- 'RESOURCES',
1845
+ 'REQUESTED',
1030
1846
  'SUBMITTED',
1031
1847
  'TOT. DURATION',
1032
1848
  'JOB DURATION',
1033
1849
  '#RECOVERIES',
1034
1850
  'STATUS',
1851
+ 'POOL',
1035
1852
  ]
1036
1853
  if show_all:
1037
1854
  # TODO: move SCHED. STATE to a separate flag (e.g. --debug)
1038
- columns += ['STARTED', 'CLUSTER', 'REGION', 'SCHED. STATE', 'DETAILS']
1855
+ columns += [
1856
+ 'WORKER_CLUSTER',
1857
+ 'WORKER_JOB_ID',
1858
+ 'STARTED',
1859
+ 'INFRA',
1860
+ 'RESOURCES',
1861
+ 'SCHED. STATE',
1862
+ 'DETAILS',
1863
+ 'GIT_COMMIT',
1864
+ ]
1039
1865
  if tasks_have_k8s_user:
1040
1866
  columns.insert(0, 'USER')
1041
1867
  job_table = log_utils.create_table(columns)
1042
1868
 
1043
1869
  status_counts: Dict[str, int] = collections.defaultdict(int)
1044
- for task in tasks:
1045
- if not task['status'].is_terminal():
1046
- status_counts[task['status'].value] += 1
1870
+ if job_status_counts:
1871
+ for status_value, count in job_status_counts.items():
1872
+ status = managed_job_state.ManagedJobStatus(status_value)
1873
+ if not status.is_terminal():
1874
+ status_counts[status_value] = count
1875
+ else:
1876
+ for task in tasks:
1877
+ if not task['status'].is_terminal():
1878
+ status_counts[task['status'].value] += 1
1047
1879
 
1048
1880
  all_tasks = tasks
1049
1881
  if max_jobs is not None:
@@ -1054,7 +1886,10 @@ def format_job_table(
1054
1886
  # by the task_id.
1055
1887
  jobs[get_hash(task)].append(task)
1056
1888
 
1057
- def generate_details(failure_reason: Optional[str]) -> str:
1889
+ def generate_details(details: Optional[str],
1890
+ failure_reason: Optional[str]) -> str:
1891
+ if details is not None:
1892
+ return details
1058
1893
  if failure_reason is not None:
1059
1894
  return f'Failure: {failure_reason}'
1060
1895
  return '-'
@@ -1083,6 +1918,8 @@ def format_job_table(
1083
1918
  for job_hash, job_tasks in jobs.items():
1084
1919
  if show_all:
1085
1920
  schedule_state = job_tasks[0]['schedule_state']
1921
+ workspace = job_tasks[0].get('workspace',
1922
+ constants.SKYPILOT_DEFAULT_WORKSPACE)
1086
1923
 
1087
1924
  if len(job_tasks) > 1:
1088
1925
  # Aggregate the tasks into a new row in the table.
@@ -1120,10 +1957,20 @@ def format_job_table(
1120
1957
 
1121
1958
  user_values = get_user_column_values(job_tasks[0])
1122
1959
 
1960
+ pool = job_tasks[0].get('pool')
1961
+ if pool is None:
1962
+ pool = '-'
1963
+
1964
+ # Add worker information if job is assigned to a worker
1123
1965
  job_id = job_hash[1] if tasks_have_k8s_user else job_hash
1966
+ # job_id is always an integer here; use it to look up the worker.
1967
+ if job_id in job_to_worker and pool != '-':
1968
+ pool = f'{pool} (worker={job_to_worker[job_id]})'
1969
+
1124
1970
  job_values = [
1125
1971
  job_id,
1126
1972
  '',
1973
+ *([''] if show_workspace else []),
1127
1974
  job_name,
1128
1975
  *user_values,
1129
1976
  '-',
@@ -1132,15 +1979,20 @@ def format_job_table(
1132
1979
  job_duration,
1133
1980
  recovery_cnt,
1134
1981
  status_str,
1982
+ pool,
1135
1983
  ]
1136
1984
  if show_all:
1985
+ details = job_tasks[current_task_id].get('details')
1137
1986
  failure_reason = job_tasks[current_task_id]['failure_reason']
1138
1987
  job_values.extend([
1988
+ '-',
1989
+ '-',
1139
1990
  '-',
1140
1991
  '-',
1141
1992
  '-',
1142
1993
  job_tasks[0]['schedule_state'],
1143
- generate_details(failure_reason),
1994
+ generate_details(details, failure_reason),
1995
+ job_tasks[0].get('metadata', {}).get('git_commit', '-'),
1144
1996
  ])
1145
1997
  if tasks_have_k8s_user:
1146
1998
  job_values.insert(0, job_tasks[0].get('user', '-'))
@@ -1153,9 +2005,20 @@ def format_job_table(
1153
2005
  0, task['job_duration'], absolute=True)
1154
2006
  submitted = log_utils.readable_time_duration(task['submitted_at'])
1155
2007
  user_values = get_user_column_values(task)
2008
+ task_workspace = '-' if len(job_tasks) > 1 else workspace
2009
+ pool = task.get('pool')
2010
+ if pool is None:
2011
+ pool = '-'
2012
+
2013
+ # Add worker information if task is assigned to a worker
2014
+ task_job_id = task['job_id']
2015
+ if task_job_id in job_to_worker and pool != '-':
2016
+ pool = f'{pool} (worker={job_to_worker[task_job_id]})'
2017
+
1156
2018
  values = [
1157
2019
  task['job_id'] if len(job_tasks) == 1 else ' \u21B3',
1158
2020
  task['task_id'] if len(job_tasks) > 1 else '-',
2021
+ *([task_workspace] if show_workspace else []),
1159
2022
  task['task_name'],
1160
2023
  *user_values,
1161
2024
  task['resources'],
@@ -1168,20 +2031,50 @@ def format_job_table(
1168
2031
  job_duration,
1169
2032
  task['recovery_count'],
1170
2033
  task['status'].colored_str(),
2034
+ pool,
1171
2035
  ]
1172
2036
  if show_all:
1173
2037
  # schedule_state is only set at the job level, so if we have
1174
2038
  # more than one task, only display on the aggregated row.
1175
2039
  schedule_state = (task['schedule_state']
1176
2040
  if len(job_tasks) == 1 else '-')
2041
+ infra_str = task.get('infra')
2042
+ if infra_str is None:
2043
+ cloud = task.get('cloud')
2044
+ if cloud is None:
2045
+ # Backward compatibility for old jobs controller without
2046
+ # cloud info returned, we parse it from the cluster
2047
+ # resources
2048
+ # TODO(zhwu): remove this after 0.12.0
2049
+ cloud = task['cluster_resources'].split('(')[0].split(
2050
+ 'x')[-1]
2051
+ task['cluster_resources'] = task[
2052
+ 'cluster_resources'].replace(f'{cloud}(',
2053
+ '(').replace(
2054
+ 'x ', 'x')
2055
+ region = task['region']
2056
+ zone = task.get('zone')
2057
+ if cloud == '-':
2058
+ cloud = None
2059
+ if region == '-':
2060
+ region = None
2061
+ if zone == '-':
2062
+ zone = None
2063
+ infra_str = infra_utils.InfraInfo(cloud, region,
2064
+ zone).formatted_str()
1177
2065
  values.extend([
2066
+ task.get('current_cluster_name', '-'),
2067
+ task.get('job_id_on_pool_cluster', '-'),
1178
2068
  # STARTED
1179
2069
  log_utils.readable_time_duration(task['start_at']),
2070
+ infra_str,
1180
2071
  task['cluster_resources'],
1181
- task['region'],
1182
2072
  schedule_state,
1183
- generate_details(task['failure_reason']),
2073
+ generate_details(task.get('details'),
2074
+ task['failure_reason']),
1184
2075
  ])
2076
+
2077
+ values.append(task.get('metadata', {}).get('git_commit', '-'))
1185
2078
  if tasks_have_k8s_user:
1186
2079
  values.insert(0, task.get('user', '-'))
1187
2080
  job_table.add_row(values)
@@ -1204,6 +2097,59 @@ def format_job_table(
1204
2097
  return output
1205
2098
 
1206
2099
 
2100
+ def decode_managed_job_protos(
2101
+ job_protos: Iterable['managed_jobsv1_pb2.ManagedJobInfo']
2102
+ ) -> List[Dict[str, Any]]:
2103
+ """Decode job protos to dicts. Similar to load_managed_job_queue."""
2104
+ user_hash_to_user = global_user_state.get_users(
2105
+ set(job.user_hash for job in job_protos if job.user_hash))
2106
+
2107
+ jobs = []
2108
+ for job_proto in job_protos:
2109
+ job_dict = _job_proto_to_dict(job_proto)
2110
+ user_hash = job_dict.get('user_hash', None)
2111
+ if user_hash is not None:
2112
+ # Skip jobs that do not have user_hash info.
2113
+ # TODO(cooperc): Remove check before 0.12.0.
2114
+ user = user_hash_to_user.get(user_hash, None)
2115
+ job_dict['user_name'] = user.name if user is not None else None
2116
+ jobs.append(job_dict)
2117
+ return jobs
2118
+
2119
+
2120
+ def _job_proto_to_dict(
2121
+ job_proto: 'managed_jobsv1_pb2.ManagedJobInfo') -> Dict[str, Any]:
2122
+ job_dict = json_format.MessageToDict(
2123
+ job_proto,
2124
+ always_print_fields_with_no_presence=True,
2125
+ # Our API returns fields in snake_case.
2126
+ preserving_proto_field_name=True,
2127
+ use_integers_for_enums=True)
2128
+ for field in job_proto.DESCRIPTOR.fields:
2129
+ # Ensure optional fields are present with None values for
2130
+ # backwards compatibility with older clients.
2131
+ if field.has_presence and field.name not in job_dict:
2132
+ job_dict[field.name] = None
2133
+ # json_format.MessageToDict is meant for encoding to JSON,
2134
+ # and Protobuf encodes int64 as decimal strings in JSON,
2135
+ # so we need to convert them back to ints.
2136
+ # https://protobuf.dev/programming-guides/json/#field-representation
2137
+ if (field.type == descriptor.FieldDescriptor.TYPE_INT64 and
2138
+ job_dict.get(field.name) is not None):
2139
+ job_dict[field.name] = int(job_dict[field.name])
2140
+ job_dict['status'] = managed_job_state.ManagedJobStatus.from_protobuf(
2141
+ job_dict['status'])
2142
+ # For backwards compatibility, convert schedule_state to a string,
2143
+ # as we don't have the logic to handle it in our request
2144
+ # encoder/decoder, unlike status.
2145
+ schedule_state_enum = (
2146
+ managed_job_state.ManagedJobScheduleState.from_protobuf(
2147
+ job_dict['schedule_state']))
2148
+ job_dict['schedule_state'] = (schedule_state_enum.value
2149
+ if schedule_state_enum is not None else None)
2150
+ return job_dict
2151
+
2152
+
1207
2153
  class ManagedJobCodeGen:
1208
2154
  """Code generator for managed job utility functions.
1209
2155
 
@@ -1221,9 +2167,62 @@ class ManagedJobCodeGen:
1221
2167
  """)
1222
2168
 
1223
2169
  @classmethod
1224
- def get_job_table(cls) -> str:
1225
- code = textwrap.dedent("""\
1226
- job_table = utils.dump_managed_job_queue()
2170
+ def get_job_table(
2171
+ cls,
2172
+ skip_finished: bool = False,
2173
+ accessible_workspaces: Optional[List[str]] = None,
2174
+ job_ids: Optional[List[int]] = None,
2175
+ workspace_match: Optional[str] = None,
2176
+ name_match: Optional[str] = None,
2177
+ pool_match: Optional[str] = None,
2178
+ page: Optional[int] = None,
2179
+ limit: Optional[int] = None,
2180
+ user_hashes: Optional[List[Optional[str]]] = None,
2181
+ statuses: Optional[List[str]] = None,
2182
+ fields: Optional[List[str]] = None,
2183
+ ) -> str:
2184
+ code = textwrap.dedent(f"""\
2185
+ if managed_job_version < 9:
2186
+ # For backward compatibility, since filtering is not supported
2187
+ # before #6652.
2188
+ # TODO(hailong): Remove compatibility before 0.12.0
2189
+ job_table = utils.dump_managed_job_queue()
2190
+ elif managed_job_version < 10:
2191
+ job_table = utils.dump_managed_job_queue(
2192
+ skip_finished={skip_finished},
2193
+ accessible_workspaces={accessible_workspaces!r},
2194
+ job_ids={job_ids!r},
2195
+ workspace_match={workspace_match!r},
2196
+ name_match={name_match!r},
2197
+ pool_match={pool_match!r},
2198
+ page={page!r},
2199
+ limit={limit!r},
2200
+ user_hashes={user_hashes!r})
2201
+ elif managed_job_version < 12:
2202
+ job_table = utils.dump_managed_job_queue(
2203
+ skip_finished={skip_finished},
2204
+ accessible_workspaces={accessible_workspaces!r},
2205
+ job_ids={job_ids!r},
2206
+ workspace_match={workspace_match!r},
2207
+ name_match={name_match!r},
2208
+ pool_match={pool_match!r},
2209
+ page={page!r},
2210
+ limit={limit!r},
2211
+ user_hashes={user_hashes!r},
2212
+ statuses={statuses!r})
2213
+ else:
2214
+ job_table = utils.dump_managed_job_queue(
2215
+ skip_finished={skip_finished},
2216
+ accessible_workspaces={accessible_workspaces!r},
2217
+ job_ids={job_ids!r},
2218
+ workspace_match={workspace_match!r},
2219
+ name_match={name_match!r},
2220
+ pool_match={pool_match!r},
2221
+ page={page!r},
2222
+ limit={limit!r},
2223
+ user_hashes={user_hashes!r},
2224
+ statuses={statuses!r},
2225
+ fields={fields!r})
1227
2226
  print(job_table, flush=True)
1228
2227
  """)
1229
2228
  return cls._build(code)
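
The ManagedJobCodeGen methods all follow the same pattern: build a Python snippet as an f-string, embed the arguments with !r so they survive as literals, and branch on managed_job_version so newer clients can still talk to older controllers. A minimal sketch of that pattern (the helper name and the preamble it relies on are illustrative, not the real codegen):

    import textwrap
    from typing import List, Optional


    def build_queue_code(skip_finished: bool,
                         job_ids: Optional[List[int]]) -> str:
        # The generated snippet assumes `utils` and `managed_job_version`
        # are defined by a preamble on the controller, as in the real codegen.
        return textwrap.dedent(f"""\
            if managed_job_version < 9:
                # Older controllers do not support filtering arguments.
                job_table = utils.dump_managed_job_queue()
            else:
                job_table = utils.dump_managed_job_queue(
                    skip_finished={skip_finished},
                    job_ids={job_ids!r})
            print(job_table, flush=True)
            """)


    print(build_queue_code(skip_finished=True, job_ids=[3, 5]))
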
@@ -1232,26 +2231,77 @@ class ManagedJobCodeGen:
1232
2231
  def cancel_jobs_by_id(cls,
1233
2232
  job_ids: Optional[List[int]],
1234
2233
  all_users: bool = False) -> str:
2234
+ active_workspace = skypilot_config.get_active_workspace()
1235
2235
  code = textwrap.dedent(f"""\
1236
2236
  if managed_job_version < 2:
1237
2237
  # For backward compatibility, since all_users is not supported
1238
- # before #4787. Assume th
2238
+ # before #4787.
1239
2239
  # TODO(cooperc): Remove compatibility before 0.12.0
1240
2240
  msg = utils.cancel_jobs_by_id({job_ids})
1241
- else:
2241
+ elif managed_job_version < 4:
2242
+ # For backward compatibility, since current_workspace is not
2243
+ # supported before #5660. Don't check the workspace.
2244
+ # TODO(zhwu): Remove compatibility before 0.12.0
1242
2245
  msg = utils.cancel_jobs_by_id({job_ids}, all_users={all_users})
2246
+ else:
2247
+ msg = utils.cancel_jobs_by_id({job_ids}, all_users={all_users},
2248
+ current_workspace={active_workspace!r})
1243
2249
  print(msg, end="", flush=True)
1244
2250
  """)
1245
2251
  return cls._build(code)
1246
2252
 
1247
2253
  @classmethod
1248
2254
  def cancel_job_by_name(cls, job_name: str) -> str:
2255
+ active_workspace = skypilot_config.get_active_workspace()
1249
2256
  code = textwrap.dedent(f"""\
1250
- msg = utils.cancel_job_by_name({job_name!r})
2257
+ if managed_job_version < 4:
2258
+ # For backward compatibility, since current_workspace is not
2259
+ # supported before #5660. Don't check the workspace.
2260
+ # TODO(zhwu): Remove compatibility before 0.12.0
2261
+ msg = utils.cancel_job_by_name({job_name!r})
2262
+ else:
2263
+ msg = utils.cancel_job_by_name({job_name!r}, {active_workspace!r})
1251
2264
  print(msg, end="", flush=True)
1252
2265
  """)
1253
2266
  return cls._build(code)
1254
2267
 
2268
+ @classmethod
2269
+ def cancel_jobs_by_pool(cls, pool_name: str) -> str:
2270
+ active_workspace = skypilot_config.get_active_workspace()
2271
+ code = textwrap.dedent(f"""\
2272
+ msg = utils.cancel_jobs_by_pool({pool_name!r}, {active_workspace!r})
2273
+ print(msg, end="", flush=True)
2274
+ """)
2275
+ return cls._build(code)
2276
+
2277
+ @classmethod
2278
+ def get_version_and_job_table(cls) -> str:
2279
+ """Generate code to get controller version and raw job table."""
2280
+ code = textwrap.dedent("""\
2281
+ from sky.skylet import constants as controller_constants
2282
+
2283
+ # Get controller version
2284
+ controller_version = controller_constants.SKYLET_VERSION
2285
+ print(f"controller_version:{controller_version}", flush=True)
2286
+
2287
+ # Get and print raw job table (load_managed_job_queue can parse this directly)
2288
+ job_table = utils.dump_managed_job_queue()
2289
+ print(job_table, flush=True)
2290
+ """)
2291
+ return cls._build(code)
2292
+
2293
+ @classmethod
2294
+ def get_version(cls) -> str:
2295
+ """Generate code to get controller version."""
2296
+ code = textwrap.dedent("""\
2297
+ from sky.skylet import constants as controller_constants
2298
+
2299
+ # Get controller version
2300
+ controller_version = controller_constants.SKYLET_VERSION
2301
+ print(f"controller_version:{controller_version}", flush=True)
2302
+ """)
2303
+ return cls._build(code)
2304
+
1255
2305
  @classmethod
1256
2306
  def get_all_job_ids_by_name(cls, job_name: Optional[str]) -> str:
1257
2307
  code = textwrap.dedent(f"""\
@@ -1266,10 +2316,16 @@ class ManagedJobCodeGen:
1266
2316
  job_name: Optional[str],
1267
2317
  job_id: Optional[int],
1268
2318
  follow: bool = True,
1269
- controller: bool = False) -> str:
2319
+ controller: bool = False,
2320
+ tail: Optional[int] = None) -> str:
1270
2321
  code = textwrap.dedent(f"""\
1271
- result = utils.stream_logs(job_id={job_id!r}, job_name={job_name!r},
1272
- follow={follow}, controller={controller})
2322
+ if managed_job_version < 6:
2323
+ # Versions before 6 did not support the tail parameter.
2324
+ result = utils.stream_logs(job_id={job_id!r}, job_name={job_name!r},
2325
+ follow={follow}, controller={controller})
2326
+ else:
2327
+ result = utils.stream_logs(job_id={job_id!r}, job_name={job_name!r},
2328
+ follow={follow}, controller={controller}, tail={tail!r})
1273
2329
  if managed_job_version < 3:
1274
2330
  # Versions 2 and older did not return a retcode, so we just print
1275
2331
  # the result.
@@ -1283,18 +2339,44 @@ class ManagedJobCodeGen:
1283
2339
  return cls._build(code)
1284
2340
 
1285
2341
  @classmethod
1286
- def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag') -> str:
2342
+ def set_pending(cls,
2343
+ job_id: int,
2344
+ managed_job_dag: 'dag_lib.Dag',
2345
+ workspace: str,
2346
+ entrypoint: str,
2347
+ user_hash: Optional[str] = None) -> str:
1287
2348
  dag_name = managed_job_dag.name
2349
+ pool = managed_job_dag.pool
1288
2350
  # Add the managed job to queue table.
1289
2351
  code = textwrap.dedent(f"""\
1290
- managed_job_state.set_job_info({job_id}, {dag_name!r})
2352
+ set_job_info_kwargs = {{'workspace': {workspace!r}}}
2353
+ if managed_job_version < 4:
2354
+ set_job_info_kwargs = {{}}
2355
+ if managed_job_version >= 5:
2356
+ set_job_info_kwargs['entrypoint'] = {entrypoint!r}
2357
+ if managed_job_version >= 8:
2358
+ from sky.serve import serve_state
2359
+ pool_hash = None
2360
+ if {pool!r} is not None:
2361
+ pool_hash = serve_state.get_service_hash({pool!r})
2362
+ set_job_info_kwargs['pool'] = {pool!r}
2363
+ set_job_info_kwargs['pool_hash'] = pool_hash
2364
+ if managed_job_version >= 11:
2365
+ set_job_info_kwargs['user_hash'] = {user_hash!r}
2366
+ managed_job_state.set_job_info(
2367
+ {job_id}, {dag_name!r}, **set_job_info_kwargs)
1291
2368
  """)
1292
2369
  for task_id, task in enumerate(managed_job_dag.tasks):
1293
2370
  resources_str = backend_utils.get_task_resources_str(
1294
2371
  task, is_managed_job=True)
1295
2372
  code += textwrap.dedent(f"""\
1296
- managed_job_state.set_pending({job_id}, {task_id},
1297
- {task.name!r}, {resources_str!r})
2373
+ if managed_job_version < 7:
2374
+ managed_job_state.set_pending({job_id}, {task_id},
2375
+ {task.name!r}, {resources_str!r})
2376
+ else:
2377
+ managed_job_state.set_pending({job_id}, {task_id},
2378
+ {task.name!r}, {resources_str!r},
2379
+ {task.metadata_json!r})
1298
2380
  """)
1299
2381
  return cls._build(code)
1300
2382