skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/server/server.py ADDED
@@ -0,0 +1,1106 @@
1
+ """SkyPilot API Server exposing RESTful APIs."""
2
+
3
+ import argparse
4
+ import asyncio
5
+ import contextlib
6
+ import dataclasses
7
+ import datetime
8
+ import logging
9
+ import os
10
+ import pathlib
11
+ import re
12
+ import shutil
13
+ import sys
14
+ from typing import Dict, List, Literal, Optional, Set, Tuple
15
+ import uuid
16
+ import zipfile
17
+
18
+ import aiofiles
19
+ import fastapi
20
+ from fastapi.middleware import cors
21
+ import starlette.middleware.base
22
+
23
+ import sky
24
+ from sky import check as sky_check
25
+ from sky import clouds
26
+ from sky import core
27
+ from sky import exceptions
28
+ from sky import execution
29
+ from sky import global_user_state
30
+ from sky import optimizer
31
+ from sky import sky_logging
32
+ from sky.clouds import service_catalog
33
+ from sky.data import storage_utils
34
+ from sky.jobs.server import server as jobs_rest
35
+ from sky.provision.kubernetes import utils as kubernetes_utils
36
+ from sky.serve.server import server as serve_rest
37
+ from sky.server import common
38
+ from sky.server import constants as server_constants
39
+ from sky.server import stream_utils
40
+ from sky.server.requests import executor
41
+ from sky.server.requests import payloads
42
+ from sky.server.requests import requests as requests_lib
43
+ from sky.skylet import constants
44
+ from sky.usage import usage_lib
45
+ from sky.utils import common as common_lib
46
+ from sky.utils import common_utils
47
+ from sky.utils import dag_utils
48
+ from sky.utils import status_lib
49
+
50
# pylint: disable=ungrouped-imports
# `typing.ParamSpec` exists only from Python 3.10; older interpreters get the
# backport from typing_extensions.
if sys.version_info < (3, 10):
    from typing_extensions import ParamSpec
else:
    from typing import ParamSpec

# Module-level ParamSpec; presumably used to type callable wrappers further
# down this file (not visible here) — confirm.
P = ParamSpec('P')
57
+
58
+
59
def _add_timestamp_prefix_for_server_logs() -> None:
    """Route server and uvicorn logs through one date-prefixed stdout handler.

    The 'sky.server' logger is reset (existing handlers dropped, propagation
    disabled so the root SkyPilot logger is unaffected) and given a single
    stdout handler formatted with ``sky_logging.FORMATTER``. The 'uvicorn' and
    'uvicorn.access' loggers have their handlers replaced by that same handler.
    """
    server_logger = sky_logging.init_logger('sky.server')
    # Start from a clean slate so repeated calls do not duplicate output.
    server_logger.handlers.clear()
    server_logger.propagate = False

    shared_handler = logging.StreamHandler(sys.stdout)
    # Flush straight through sys.stdout so log lines appear promptly.
    shared_handler.flush = sys.stdout.flush  # type: ignore
    shared_handler.setFormatter(sky_logging.FORMATTER)
    server_logger.addHandler(shared_handler)

    # Give uvicorn's own loggers the same date-prefixed output.
    for uvicorn_logger in map(logging.getLogger, ('uvicorn', 'uvicorn.access')):
        uvicorn_logger.handlers.clear()
        uvicorn_logger.addHandler(shared_handler)
76
+
77
+
78
# Install the server log handlers first, then create this module's logger.
_add_timestamp_prefix_for_server_logs()
logger = sky_logging.init_logger(__name__)
80
+
81
+ # TODO(zhwu): Streaming requests, such as log tailing after sky launch or sky logs,
82
+ # need to be detached from the main requests queue. Otherwise, the streaming
83
+ # response will block other requests from being processed.
84
+
85
+
86
class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Tags every request, and its response, with a freshly generated UUID4.

    The ID is stored on ``request.state.request_id`` for endpoints to use and
    is echoed back to the client in the ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: fastapi.Request, call_next):
        new_id = str(uuid.uuid4())
        request.state.request_id = new_id
        response = await call_next(request)
        response.headers['X-Request-ID'] = new_id
        return response
95
+
96
+
97
# How long after its most recent chunk an upload is kept before its temporary
# files become eligible for cleanup.
_DEFAULT_UPLOAD_EXPIRATION_TIME = datetime.timedelta(minutes=60)
# Maps (upload_id, user_hash) to the time after which that upload's temporary
# files may be cleaned up.
upload_ids_to_cleanup: Dict[Tuple[str, str], datetime.datetime] = dict()
102
+
103
+
104
async def cleanup_upload_ids():
    """Cleans up the temporary chunks uploaded by the client after a delay.

    Runs forever. Once an hour, for every entry of ``upload_ids_to_cleanup``
    whose expiration time has passed, it drops the entry and removes
    ``<user_hash>/file_mounts/<upload_id>`` (the chunk directory) and
    ``<user_hash>/file_mounts/<upload_id>.zip`` under
    ``common.API_SERVER_CLIENT_DIR``. This keeps stale chunks from taking up
    space on the API server.

    Cleanup is best-effort. A failure for one upload is logged and the loop
    keeps running; an uncaught exception would otherwise end this background
    task permanently. Entries whose upload_id or user_hash is not a single
    plain path component are dropped without touching the filesystem, so a
    malformed entry can never make ``rmtree`` leave the client directory.
    """
    while True:
        await asyncio.sleep(3600)
        current_time = datetime.datetime.now()
        # Iterate over a snapshot so entries can be removed while looping.
        for key, expire_time in list(upload_ids_to_cleanup.items()):
            if current_time <= expire_time:
                continue
            upload_id, user_hash = key
            upload_ids_to_cleanup.pop(key, None)
            # Both names become path components below; reject anything that
            # could point outside the per-user directory ('..', 'a/b', '/x').
            if any(name in ('', '..') or pathlib.PurePath(name).name != name
                   for name in key):
                logger.warning('Skipping cleanup of invalid upload entry: '
                               f'upload_id={upload_id!r}, '
                               f'user_hash={user_hash!r}')
                continue
            logger.info(f'Cleaning up upload id: {upload_id}')
            client_file_mounts_dir = (
                common.API_SERVER_CLIENT_DIR.expanduser().resolve() /
                user_hash / 'file_mounts')
            try:
                shutil.rmtree(client_file_mounts_dir / upload_id,
                              ignore_errors=True)
                (client_file_mounts_dir /
                 upload_id).with_suffix('.zip').unlink(missing_ok=True)
            except (OSError, ValueError) as e:
                logger.warning(f'Failed to clean up upload id {upload_id}: '
                               f'{common_utils.format_exception(e)}')
124
+
125
+
126
@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI):  # pylint: disable=redefined-outer-name
    """FastAPI lifespan context manager.

    Startup: schedules each daemon in ``requests_lib.INTERNAL_REQUEST_DAEMONS``
    as a short SkyPilot-system request. It also starts the background task
    (``cleanup_upload_ids``) that removes stale client uploads.

    Shutdown: cancels that background task.
    """
    del app  # unused
    for event in requests_lib.INTERNAL_REQUEST_DAEMONS:
        executor.schedule_request(
            request_id=event.id,
            request_name=event.name,
            request_body=payloads.RequestBody(),
            func=event.event_fn,
            schedule_type=requests_lib.ScheduleType.SHORT,
            is_skypilot_system=True,
        )
    # Keep a strong reference to the task. The event loop only holds weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    cleanup_task = asyncio.create_task(cleanup_upload_ids())
    try:
        yield
    finally:
        # Shutdown: stop the infinite cleanup loop rather than leaving it
        # pending when the event loop closes.
        cleanup_task.cancel()
143
+
144
+
145
# NOTE(review): `prefix` does not appear to be a FastAPI() constructor
# parameter (unknown keyword arguments are kept in `app.extra`). If so, the
# routes below are not actually mounted under /api/v1 — confirm.
app = fastapi.FastAPI(debug=True, lifespan=lifespan, prefix='/api/v1')
app.add_middleware(
    cors.CORSMiddleware,
    # TODO(zhwu): in production deployment, we should restrict the allowed
    # origins to the domains that are allowed to access the API server.
    allow_origins=['*'],  # Specify the correct domains for production
    allow_methods=['*'],
    allow_headers=['*'],
    allow_credentials=True,
    # Let browser clients read the ID set by RequestIDMiddleware.
    expose_headers=['X-Request-ID'])
app.add_middleware(RequestIDMiddleware)
# Sub-APIs for managed jobs and SkyServe.
app.include_router(jobs_rest.router, tags=['jobs'], prefix='/jobs')
app.include_router(serve_rest.router, tags=['serve'], prefix='/serve')
158
+
159
+
160
@app.post('/check')
async def check(request: fastapi.Request,
                check_body: payloads.CheckBody) -> None:
    """Schedules ``sky.check.check`` as a short request (enabled clouds)."""
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='check',
        func=sky_check.check,
        request_body=check_body,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
171
+
172
+
173
@app.get('/enabled_clouds')
async def enabled_clouds(request: fastapi.Request) -> None:
    """Schedules ``core.enabled_clouds`` as a short request.

    The request takes no parameters, so an empty ``RequestBody`` is used.
    """
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='enabled_clouds',
        func=core.enabled_clouds,
        request_body=payloads.RequestBody(),
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
183
+
184
+
185
@app.post('/realtime_kubernetes_gpu_availability')
async def realtime_kubernetes_gpu_availability(
    request: fastapi.Request,
    realtime_gpu_availability_body: payloads.RealtimeGpuAvailabilityRequestBody
) -> None:
    """Schedules ``core.realtime_kubernetes_gpu_availability`` (short)."""
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='realtime_kubernetes_gpu_availability',
        func=core.realtime_kubernetes_gpu_availability,
        request_body=realtime_gpu_availability_body,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
198
+
199
+
200
@app.post('/kubernetes_node_info')
async def kubernetes_node_info(
    request: fastapi.Request,
    kubernetes_node_info_body: payloads.KubernetesNodeInfoRequestBody
) -> None:
    """Schedules ``kubernetes_utils.get_kubernetes_node_info`` (short)."""
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='kubernetes_node_info',
        func=kubernetes_utils.get_kubernetes_node_info,
        request_body=kubernetes_node_info_body,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
213
+
214
+
215
@app.get('/status_kubernetes')
async def status_kubernetes(request: fastapi.Request) -> None:
    """Schedules ``core.status_kubernetes`` as a short, parameterless request."""
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='status_kubernetes',
        func=core.status_kubernetes,
        request_body=payloads.RequestBody(),
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
225
+
226
+
227
@app.post('/list_accelerators')
async def list_accelerators(
        request: fastapi.Request,
        list_accelerator_counts_body: payloads.ListAcceleratorsBody) -> None:
    """Schedules ``service_catalog.list_accelerators`` (catalog lookup)."""
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='list_accelerators',
        func=service_catalog.list_accelerators,
        request_body=list_accelerator_counts_body,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
239
+
240
+
241
@app.post('/list_accelerator_counts')
async def list_accelerator_counts(
    request: fastapi.Request,
    list_accelerator_counts_body: payloads.ListAcceleratorCountsBody
) -> None:
    """Schedules ``service_catalog.list_accelerator_counts`` (short)."""
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='list_accelerator_counts',
        func=service_catalog.list_accelerator_counts,
        request_body=list_accelerator_counts_body,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
254
+
255
+
256
@app.post('/validate')
async def validate(validate_body: payloads.ValidateBody) -> None:
    """Synchronously validates the tasks of the user's chain DAG.

    Checks each task's name, run section and resources. Any failure is
    returned to the client as an HTTP 400 carrying the serialized exception.
    """
    # TODO(SKY-1035): validate if existing cluster satisfies the requested
    # resources, e.g. sky exec --gpus V100:8 existing-cluster-with-no-gpus
    logger.debug(f'Validating tasks: {validate_body.dag}')
    try:
        chain_dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag)
        for task in chain_dag.tasks:
            # workdir and file_mounts are validated later in the backend,
            # once the files have been uploaded to the SkyPilot API server
            # with `upload_mounts_to_api_server`.
            task.validate_name()
            task.validate_run()
            for resources in task.resources:
                resources.validate()
    except Exception as e:  # pylint: disable=broad-except
        serialized = exceptions.serialize_exception(e)
        raise fastapi.HTTPException(status_code=400, detail=serialized) from e
275
+
276
+
277
@app.post('/optimize')
async def optimize(optimize_body: payloads.OptimizeBody,
                   request: fastapi.Request) -> None:
    """Schedules ``optimizer.Optimizer.optimize`` for the user's DAG (short).

    The request is scheduled with ``ignore_return_value=True``.
    """
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='optimize',
        func=optimizer.Optimizer.optimize,
        request_body=optimize_body,
        schedule_type=requests_lib.ScheduleType.SHORT,
        ignore_return_value=True,
    )
289
+
290
+
291
@app.post('/upload')
async def upload_zip_file(request: fastapi.Request, user_hash: str,
                          upload_id: str, chunk_index: int,
                          total_chunks: int) -> payloads.UploadZipFileResponse:
    """Uploads a zip file to the API server.

    This endpoint can be called multiple times for the same upload_id with
    different chunk_index. The server will merge the chunks and unzip the file
    when all chunks are uploaded.

    This implementation is simplified and may need to be improved in the future,
    e.g., adopting S3-style multipart upload.

    Args:
        request: The incoming request; its body is the raw chunk content.
        user_hash: The user hash. It is used as a directory name, so it must
            be a single plain path component.
        upload_id: The upload id, a valid SkyPilot run_timestamp appended with 8
            hex characters, e.g. 'sky-2025-01-17-09-10-13-933602-35d31c22'.
        chunk_index: The chunk index, starting from 0.
        total_chunks: The total number of chunks.

    Returns:
        ``status='uploading'`` with the names of the still-missing chunks
        while a multi-chunk upload is incomplete. Otherwise
        ``status='completed'``, once the zip file has been assembled and
        extracted.

    Raises:
        fastapi.HTTPException: 400 if a parameter is invalid or the client
            disconnected mid-upload; 500 if writing the chunk failed. Errors
            raised by ``unzip_file`` propagate unchanged.
    """
    # Validate every parameter BEFORE any side effect. Both upload_id and
    # user_hash become path components below, and the cleanup task later
    # calls shutil.rmtree() on paths built from the registered pair. Without
    # this, a crafted value could make the server create, write or delete
    # files outside the client directory.
    # fullmatch (instead of match with `$`) also rejects a trailing newline.
    if not re.fullmatch(
            r'sky-[0-9]{4}-[0-9]{2}-[0-9]{2}-[0-9]{2}-[0-9]{2}-'
            r'[0-9]{2}-[0-9]{6}-[0-9a-f]{8}', upload_id):
        raise fastapi.HTTPException(
            status_code=400,
            detail=(f'Invalid upload_id: {upload_id}. Expected a SkyPilot '
                    'run timestamp followed by 8 hex characters, e.g. '
                    "'sky-2025-01-17-09-10-13-933602-35d31c22'."))
    if (user_hash in ('', '..') or
            pathlib.PurePath(user_hash).name != user_hash):
        raise fastapi.HTTPException(
            status_code=400, detail=f'Invalid user_hash: {user_hash!r}.')
    # Check total_chunks first so chunk_index is validated against a sane
    # upper bound.
    if total_chunks < 1:
        raise fastapi.HTTPException(
            status_code=400,
            detail=(f'Invalid total_chunks: {total_chunks}. '
                    'It must be at least 1.'))
    if not 0 <= chunk_index < total_chunks:
        raise fastapi.HTTPException(
            status_code=400,
            detail=(f'Invalid chunk_index: {chunk_index}. It must be in '
                    f'[0, {total_chunks}).'))

    # Register (or refresh) the upload for cleanup in case it is abandoned.
    upload_ids_to_cleanup[(upload_id, user_hash)] = (
        datetime.datetime.now() + _DEFAULT_UPLOAD_EXPIRATION_TIME)

    # TODO(SKY-1271): We need to double check security of uploading zip file.
    client_file_mounts_dir = (
        common.API_SERVER_CLIENT_DIR.expanduser().resolve() / user_hash /
        'file_mounts')
    client_file_mounts_dir.mkdir(parents=True, exist_ok=True)

    if total_chunks == 1:
        # Single chunk: the request body is the whole zip file.
        zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
    else:
        # Multiple chunks: stage each one under <upload_id>/ with an
        # `.incomplete` suffix until it has been fully written.
        chunk_dir = client_file_mounts_dir / upload_id
        chunk_dir.mkdir(parents=True, exist_ok=True)
        zip_file_path = chunk_dir / f'part{chunk_index}.incomplete'

    try:
        async with aiofiles.open(zip_file_path, 'wb') as f:
            async for data in request.stream():
                await f.write(data)
    except starlette.requests.ClientDisconnect as e:
        # Client disconnected, remove the partially written file.
        zip_file_path.unlink(missing_ok=True)
        raise fastapi.HTTPException(
            status_code=400,
            detail='Client disconnected, please try again.') from e
    except Exception as e:  # pylint: disable=broad-except
        logger.error(f'Error uploading zip file: {zip_file_path}')
        # Writing failed, remove the partially written file.
        zip_file_path.unlink(missing_ok=True)
        raise fastapi.HTTPException(
            status_code=500,
            detail=('Error uploading zip file: '
                    f'{common_utils.format_exception(e)}')) from e

    if total_chunks > 1:
        # Mark this chunk as complete by dropping the `.incomplete` suffix.
        zip_file_path.rename(zip_file_path.with_suffix(''))
        # `.incomplete` names from in-flight chunks never match the expected
        # `part<i>` names, so they are still counted as missing.
        uploaded_chunks = {p.name for p in chunk_dir.glob('part*')}
        missing_chunks: Set[str] = (
            {f'part{i}' for i in range(total_chunks)} - uploaded_chunks)
        if missing_chunks:
            return payloads.UploadZipFileResponse(status='uploading',
                                                  missing_chunks=missing_chunks)
        # All chunks are present: concatenate them, in order, into one zip.
        zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
        async with aiofiles.open(zip_file_path, 'wb') as zip_file:
            for part_index in range(total_chunks):
                async with aiofiles.open(chunk_dir / f'part{part_index}',
                                         'rb') as f:
                    while True:
                        # Use 64KB buffer to avoid memory overflow, same size
                        # as shutil.copyfileobj.
                        data = await f.read(64 * 1024)
                        if not data:
                            break
                        await zip_file.write(data)

    logger.info(f'Uploaded zip file: {zip_file_path}')
    unzip_file(zip_file_path, client_file_mounts_dir)
    if total_chunks > 1:
        shutil.rmtree(chunk_dir)
    return payloads.UploadZipFileResponse(status='completed')
392
+
393
+
394
def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
    """Returns whether ``path`` lies at or under ``parent``.

    This is a lexical comparison, like ``PurePath.is_relative_to``, which is
    not used because it was only added in Python 3.9. Callers that need
    symlink-aware results must resolve both paths first.
    """
    try:
        path.relative_to(parent)
    except ValueError:
        return False
    return True
402
+
403
+
404
def unzip_file(zip_file_path: pathlib.Path,
               client_file_mounts_dir: pathlib.Path) -> None:
    """Unzips a zip file into ``client_file_mounts_dir``, then deletes it.

    Directories, regular files and symlinks stored in the archive are
    recreated under ``client_file_mounts_dir``. A leading '/' in a member name
    is stripped, so absolute member names also land inside the directory.
    Every destination must resolve (following symlinks) to a location inside
    ``client_file_mounts_dir``, and so must every symlink target. Otherwise
    extraction aborts, which prevents "zip slip" writes outside the user's
    directory. The zip file is removed whether or not extraction succeeds.

    Args:
        zip_file_path: The zip file to extract.
        client_file_mounts_dir: The directory to extract into.

    Raises:
        fastapi.HTTPException: 400 if the file is not a valid zip file; 500 for
            any other extraction error, including a member or symlink target
            that would escape ``client_file_mounts_dir``.
    """
    # Candidate paths are resolved before checking, so compare them against
    # the resolved directory as well.
    root_dir = client_file_mounts_dir.resolve()
    try:
        with zipfile.ZipFile(zip_file_path, 'r') as zipf:
            for member in zipf.infolist():
                # Determine the new path. normpath collapses `..` segments;
                # any that remain are caught by the containment checks below.
                original_path = os.path.normpath(member.filename)
                new_path = client_file_mounts_dir / original_path.lstrip('/')

                if (member.external_attr >> 28) == 0xA:
                    # Symlink: the Unix file-type bits stored in the upper
                    # bits of external_attr mark a link, and the member's
                    # content is the link target.
                    # The link itself must be created inside the directory.
                    link_location = new_path.parent.resolve() / new_path.name
                    if not _is_relative_to(link_location, root_dir):
                        raise ValueError(f'Symlink {member.filename} leads to '
                                         'a file not in userspace. Aborted.')
                    new_path.parent.mkdir(parents=True, exist_ok=True)
                    target = zipf.read(member).decode()
                    # An explicit check rather than `assert`, which is
                    # stripped when Python runs with -O.
                    if os.path.isabs(target):
                        raise ValueError(f'Symlink target {target} is an '
                                         'absolute path. Aborted.')
                    # Since target is a relative path, we need to check that it
                    # is under `client_file_mounts_dir` for security.
                    full_target_path = (new_path.parent / target).resolve()
                    if not _is_relative_to(full_target_path, root_dir):
                        raise ValueError(f'Symlink target {target} leads to a '
                                         'file not in userspace. Aborted.')

                    if new_path.exists() or new_path.is_symlink():
                        new_path.unlink(missing_ok=True)
                    new_path.symlink_to(
                        target,
                        target_is_directory=member.filename.endswith('/'))
                    continue

                # Directories and regular files: resolve the destination,
                # following any symlinks created by earlier members, so that
                # nothing is written outside the directory.
                if not _is_relative_to(new_path.resolve(), root_dir):
                    raise ValueError(f'Zip member {member.filename} leads to '
                                     'a file not in userspace. Aborted.')

                # Handle directories
                if member.filename.endswith('/'):
                    new_path.mkdir(parents=True, exist_ok=True)
                    continue

                # Handle files
                new_path.parent.mkdir(parents=True, exist_ok=True)
                with zipf.open(member) as member_file, new_path.open('wb') as f:
                    # Use shutil.copyfileobj to copy files in chunks, so it does
                    # not load the entire file into memory.
                    shutil.copyfileobj(member_file, f)
    except zipfile.BadZipFile as e:
        logger.error(f'Bad zip file: {zip_file_path}')
        raise fastapi.HTTPException(
            status_code=400,
            detail=f'Invalid zip file: {common_utils.format_exception(e)}'
        ) from e
    except Exception as e:  # pylint: disable=broad-except
        logger.error(f'Error unzipping file: {zip_file_path}')
        raise fastapi.HTTPException(
            status_code=500,
            detail=(f'Error unzipping file: '
                    f'{common_utils.format_exception(e)}')) from e
    finally:
        # Cleanup the temporary file, also on failure so it does not linger.
        zip_file_path.unlink(missing_ok=True)
459
+
460
+
461
@app.post('/launch')
async def launch(launch_body: payloads.LaunchBody,
                 request: fastapi.Request) -> None:
    """Schedule ``execution.launch`` for a cluster or task.

    The work is queued as a LONG request, tagged with
    ``launch_body.cluster_name`` as the cluster it targets.
    """
    req_id = request.state.request_id
    logger.info(f'Launching request: {req_id}')
    executor.schedule_request(
        req_id,
        func=execution.launch,
        request_name='launch',
        request_body=launch_body,
        request_cluster_name=launch_body.cluster_name,
        schedule_type=requests_lib.ScheduleType.LONG,
    )
475
+
476
+
477
@app.post('/exec')
# pylint: disable=redefined-builtin
async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
    """Schedule ``execution.exec`` to run a task on an existing cluster."""
    target_cluster = exec_body.cluster_name
    executor.schedule_request(
        func=execution.exec,
        schedule_type=requests_lib.ScheduleType.LONG,
        request_id=request.state.request_id,
        request_name='exec',
        request_body=exec_body,
        request_cluster_name=target_cluster,
    )
489
+
490
+
491
@app.post('/stop')
async def stop(request: fastapi.Request,
               stop_body: payloads.StopOrDownBody) -> None:
    """Queue ``core.stop`` (SHORT) for the cluster named in the body."""
    cluster_name = stop_body.cluster_name
    executor.schedule_request(
        func=core.stop,
        request_id=request.state.request_id,
        request_name='stop',
        request_body=stop_body,
        request_cluster_name=cluster_name,
        schedule_type=requests_lib.ScheduleType.SHORT)
503
+
504
+
505
@app.post('/status')
async def status(
    request: fastapi.Request,
    status_body: payloads.StatusBody = payloads.StatusBody()
) -> None:
    """Queue ``core.status`` to report cluster statuses.

    A request that asks for any refresh (``refresh`` other than
    ``StatusRefreshMode.NONE``) is scheduled as LONG; otherwise as SHORT.
    """
    needs_refresh = status_body.refresh != common_lib.StatusRefreshMode.NONE
    if needs_refresh:
        schedule_type = requests_lib.ScheduleType.LONG
    else:
        schedule_type = requests_lib.ScheduleType.SHORT
    executor.schedule_request(
        request_id=request.state.request_id,
        request_name='status',
        request_body=status_body,
        func=core.status,
        schedule_type=schedule_type,
    )
520
+
521
+
522
@app.post('/endpoints')
async def endpoints(request: fastapi.Request,
                    endpoint_body: payloads.EndpointsBody) -> None:
    """Queue ``core.endpoints`` to look up a cluster's exposed endpoint(s)."""
    cluster = endpoint_body.cluster
    executor.schedule_request(
        func=core.endpoints,
        schedule_type=requests_lib.ScheduleType.SHORT,
        request_name='endpoints',
        request_body=endpoint_body,
        request_id=request.state.request_id,
        request_cluster_name=cluster,
    )
534
+
535
+
536
@app.post('/down')
async def down(request: fastapi.Request,
               down_body: payloads.StopOrDownBody) -> None:
    """Queue ``core.down`` (SHORT) to tear down the named cluster."""
    cluster_name = down_body.cluster_name
    executor.schedule_request(
        func=core.down,
        request_id=request.state.request_id,
        request_name='down',
        request_body=down_body,
        request_cluster_name=cluster_name,
        schedule_type=requests_lib.ScheduleType.SHORT)
548
+
549
+
550
@app.post('/start')
async def start(request: fastapi.Request,
                start_body: payloads.StartBody) -> None:
    """Queue ``core.start`` (LONG) to restart the named cluster."""
    cluster_name = start_body.cluster_name
    executor.schedule_request(
        func=core.start,
        schedule_type=requests_lib.ScheduleType.LONG,
        request_id=request.state.request_id,
        request_name='start',
        request_body=start_body,
        request_cluster_name=cluster_name)
562
+
563
+
564
@app.post('/autostop')
async def autostop(request: fastapi.Request,
                   autostop_body: payloads.AutostopBody) -> None:
    """Queue ``core.autostop`` to set autostop/autodown on a cluster."""
    cluster_name = autostop_body.cluster_name
    executor.schedule_request(
        func=core.autostop,
        request_id=request.state.request_id,
        request_name='autostop',
        request_body=autostop_body,
        request_cluster_name=cluster_name,
        schedule_type=requests_lib.ScheduleType.SHORT)
576
+
577
+
578
@app.post('/queue')
async def queue(request: fastapi.Request,
                queue_body: payloads.QueueBody) -> None:
    """Queue ``core.queue`` to fetch the job queue of a cluster."""
    cluster_name = queue_body.cluster_name
    executor.schedule_request(
        func=core.queue,
        schedule_type=requests_lib.ScheduleType.SHORT,
        request_id=request.state.request_id,
        request_name='queue',
        request_body=queue_body,
        request_cluster_name=cluster_name)
590
+
591
+
592
@app.post('/job_status')
async def job_status(request: fastapi.Request,
                     job_status_body: payloads.JobStatusBody) -> None:
    """Queue ``core.job_status`` to fetch job status on a cluster."""
    cluster_name = job_status_body.cluster_name
    executor.schedule_request(
        func=core.job_status,
        request_id=request.state.request_id,
        request_name='job_status',
        request_body=job_status_body,
        request_cluster_name=cluster_name,
        schedule_type=requests_lib.ScheduleType.SHORT)
604
+
605
+
606
@app.post('/cancel')
async def cancel(request: fastapi.Request,
                 cancel_body: payloads.CancelBody) -> None:
    """Queue ``core.cancel`` to cancel jobs on a cluster."""
    cluster_name = cancel_body.cluster_name
    executor.schedule_request(
        func=core.cancel,
        schedule_type=requests_lib.ScheduleType.SHORT,
        request_id=request.state.request_id,
        request_name='cancel',
        request_body=cancel_body,
        request_cluster_name=cluster_name)
618
+
619
+
620
@app.post('/logs')
async def logs(
    request: fastapi.Request, cluster_job_body: payloads.ClusterJobBody,
    background_tasks: fastapi.BackgroundTasks
) -> fastapi.responses.StreamingResponse:
    """Schedule ``core.tail_logs`` and stream the request's log back.

    The request is scheduled first, then looked up again through
    ``requests_lib.get_request`` to find the log file that
    ``stream_utils.stream_response`` streams to the client.
    """
    req_id = request.state.request_id
    # TODO(zhwu): This should wait for the request on the cluster, e.g., async
    # launch, to finish, so that a user does not need to manually pull the
    # request status.
    executor.schedule_request(
        func=core.tail_logs,
        request_id=req_id,
        request_name='logs',
        request_body=cluster_job_body,
        request_cluster_name=cluster_job_body.cluster_name,
        # TODO(aylei): We have tail logs scheduled as SHORT request, because it
        # should be responsive. However, it can be long running if the user's
        # job keeps running, and we should avoid it taking the SHORT worker.
        schedule_type=requests_lib.ScheduleType.SHORT,
    )

    scheduled = requests_lib.get_request(req_id)

    # TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
    # the same approach as /stream.
    return stream_utils.stream_response(
        request_id=scheduled.request_id,
        logs_path=scheduled.log_path,
        background_tasks=background_tasks,
    )
650
+
651
+
652
@app.post('/download_logs')
async def download_logs(
        request: fastapi.Request,
        cluster_jobs_body: payloads.ClusterJobsDownloadLogsBody) -> None:
    """Schedule ``core.download_logs`` into the caller's API-server log dir.

    The per-user directory is derived from the ``constants.USER_ID_ENV_VAR``
    entry of the body's env vars. It is created if missing and stored in
    ``cluster_jobs_body.local_dir`` before scheduling.
    """
    user_logs_dir = common.api_server_user_logs_dir_prefix(
        cluster_jobs_body.env_vars[constants.USER_ID_ENV_VAR])
    user_logs_dir.expanduser().mkdir(parents=True, exist_ok=True)
    # Mutate and reuse the incoming body instead of building a new one, so
    # that its env vars (e.g. the user hash) are preserved.
    cluster_jobs_body.local_dir = str(user_logs_dir)
    executor.schedule_request(
        func=core.download_logs,
        schedule_type=requests_lib.ScheduleType.SHORT,
        request_id=request.state.request_id,
        request_name='download_logs',
        request_body=cluster_jobs_body,
        request_cluster_name=cluster_jobs_body.cluster_name,
    )
671
+
672
+
673
@app.post('/download')
async def download(
        download_body: payloads.DownloadBody
) -> fastapi.responses.FileResponse:
    """Zips folders from the user's API-server log dir and returns the zip.

    Each entry of ``download_body.folder_paths`` must lie inside the per-user
    log directory ``common.api_server_user_logs_dir_prefix(user_hash)`` and
    must exist. The temporary zip file is deleted after the response is sent.

    Raises:
        fastapi.HTTPException: 400 if a folder is outside the user's log
            directory, 404 if it does not exist, 500 if zipping fails.
    """
    folder_paths = [
        pathlib.Path(folder_path) for folder_path in download_body.folder_paths
    ]
    user_hash = download_body.env_vars[constants.USER_ID_ENV_VAR]
    logs_dir_on_api_server = common.api_server_user_logs_dir_prefix(user_hash)
    # Compare normalized absolute paths instead of raw string prefixes. A raw
    # `startswith` check accepts `<logs_dir>/../../etc` as well as sibling
    # directories such as `<logs_dir>-other`.
    allowed_dir = os.path.abspath(
        os.path.expanduser(str(logs_dir_on_api_server)))
    for folder_path in folder_paths:
        normalized_path = os.path.abspath(os.path.expanduser(str(folder_path)))
        if os.path.commonpath([normalized_path, allowed_dir]) != allowed_dir:
            raise fastapi.HTTPException(
                status_code=400,
                detail=
                f'Invalid folder path: {folder_path}; {logs_dir_on_api_server}')

        if not folder_path.expanduser().resolve().exists():
            raise fastapi.HTTPException(
                status_code=404, detail=f'Folder not found: {folder_path}')

    # Create a temporary zip file
    log_id = str(uuid.uuid4().hex)
    zip_filename = f'folder_{log_id}.zip'
    zip_path = pathlib.Path(
        logs_dir_on_api_server).expanduser().resolve() / zip_filename

    try:
        folders = [
            str(folder_path.expanduser().resolve())
            for folder_path in folder_paths
        ]
        storage_utils.zip_files_and_folders(folders, zip_path)
    except Exception as e:  # pylint: disable=broad-except
        # Do not leave a partially written archive behind.
        zip_path.unlink(missing_ok=True)
        raise fastapi.HTTPException(
            status_code=500, detail=f'Error creating zip file: {str(e)}') from e

    # Add home path to the response headers, so that the client can replace
    # the remote path in the zip file to the local path.
    headers = {
        'Content-Disposition': f'attachment; filename="{zip_filename}"',
        'X-Home-Path': str(pathlib.Path.home())
    }

    # `BackgroundTasks.add_task` returns None, so the container must be
    # created first and then handed to the response. Passing the result of
    # `add_task` directly means the zip file is never deleted.
    cleanup_tasks = fastapi.BackgroundTasks()
    cleanup_tasks.add_task(zip_path.unlink, missing_ok=True)

    # Return the zip file as a download
    return fastapi.responses.FileResponse(path=zip_path,
                                          filename=zip_filename,
                                          media_type='application/zip',
                                          headers=headers,
                                          background=cleanup_tasks)
723
+
724
+
725
@app.get('/cost_report')
async def cost_report(request: fastapi.Request) -> None:
    """Queue ``core.cost_report`` (SHORT) with an empty request body."""
    empty_body = payloads.RequestBody()
    executor.schedule_request(
        func=core.cost_report,
        schedule_type=requests_lib.ScheduleType.SHORT,
        request_id=request.state.request_id,
        request_name='cost_report',
        request_body=empty_body)
735
+
736
+
737
@app.get('/storage/ls')
async def storage_ls(request: fastapi.Request) -> None:
    """Queue ``core.storage_ls`` (SHORT) to list storages."""
    empty_body = payloads.RequestBody()
    executor.schedule_request(
        func=core.storage_ls,
        schedule_type=requests_lib.ScheduleType.SHORT,
        request_id=request.state.request_id,
        request_name='storage_ls',
        request_body=empty_body)
747
+
748
+
749
@app.post('/storage/delete')
async def storage_delete(request: fastapi.Request,
                         storage_body: payloads.StorageBody) -> None:
    """Queue ``core.storage_delete`` (LONG) to delete a storage."""
    executor.schedule_request(
        func=core.storage_delete,
        schedule_type=requests_lib.ScheduleType.LONG,
        request_body=storage_body,
        request_id=request.state.request_id,
        request_name='storage_delete')
760
+
761
+
762
@app.post('/local_up')
async def local_up(request: fastapi.Request,
                   local_up_body: payloads.LocalUpBody) -> None:
    """Queue ``core.local_up`` (LONG) to create a Kubernetes cluster here."""
    executor.schedule_request(
        func=core.local_up,
        schedule_type=requests_lib.ScheduleType.LONG,
        request_body=local_up_body,
        request_id=request.state.request_id,
        request_name='local_up')
773
+
774
+
775
@app.post('/local_down')
async def local_down(request: fastapi.Request) -> None:
    """Queue ``core.local_down`` (LONG) to remove the local_up cluster."""
    empty_body = payloads.RequestBody()
    executor.schedule_request(
        func=core.local_down,
        schedule_type=requests_lib.ScheduleType.LONG,
        request_body=empty_body,
        request_id=request.state.request_id,
        request_name='local_down')
785
+
786
+
787
+ # === API server related APIs ===
788
@app.get('/api/get')
async def api_get(request_id: str) -> requests_lib.RequestPayload:
    """Wait for a request to finish, then return its encoded payload.

    Polls ``requests_lib.get_request`` every 0.1 s while the request status is
    not past ``RequestStatus.RUNNING``.

    Raises:
        fastapi.HTTPException: 404 if the request is unknown; 500 (with the
            encoded request as detail) if the finished request has an error.
    """
    while True:
        task = requests_lib.get_request(request_id)
        if task is None:
            print(f'No task with request ID {request_id}', flush=True)
            raise fastapi.HTTPException(
                status_code=404, detail=f'Request {request_id!r} not found')

        finished = task.status > requests_lib.RequestStatus.RUNNING
        if not finished:
            # Yield to other coroutines; the short sleep keeps the polling
            # from hammering the DB and CPU.
            await asyncio.sleep(0.1)
            continue

        if task.get_error() is None:
            return task.encode()
        raise fastapi.HTTPException(status_code=500,
                                    detail=dataclasses.asdict(task.encode()))
807
+
808
+
809
@app.get('/api/stream')
async def stream(
    request: fastapi.Request,
    request_id: Optional[str] = None,
    log_path: Optional[str] = None,
    tail: Optional[int] = None,
    follow: bool = True,
    # Choices: 'auto', 'plain', 'html', 'console'
    # 'auto': automatically choose between HTML and plain text
    # based on the request source
    # 'plain': plain text for HTML clients
    # 'html': HTML for browsers
    # 'console': console for CLI/API clients
    # pylint: disable=redefined-builtin
    format: Literal['auto', 'plain', 'html', 'console'] = 'auto',
) -> fastapi.responses.Response:
    """Stream the log of a request, or of a log file, to the client.

    At most one of ``request_id`` / ``log_path`` may be given. With neither,
    the latest request (``requests_lib.get_latest_request_id()``) is used.
    When HTML is selected (``format='html'``, or ``format='auto'`` with a
    browser-like User-Agent), the ``html/log.html`` page is returned. It
    re-requests this endpoint with ``format='plain'`` to fetch the log text.

    Args:
        request_id: Request ID to stream logs for.
        log_path: Log path to stream logs for.
        tail: Number of lines to stream from the end of the log file.
        follow: Whether to follow the log file.
        format: Response format - 'auto' (HTML for browsers, plain for HTML
            clients, console for CLI/API clients), 'plain' (force plain text),
            'html' (force HTML), or 'console' (force console)

    Raises:
        fastapi.HTTPException: 400 if both ``request_id`` and ``log_path``
            are given or ``log_path`` escapes the SkyPilot logs directory;
            404 if no matching request or log file exists.
    """
    if request_id is not None and log_path is not None:
        raise fastapi.HTTPException(
            status_code=400,
            detail='Only one of request_id and log_path can be provided')

    if request_id is None and log_path is None:
        request_id = requests_lib.get_latest_request_id()
        if request_id is None:
            raise fastapi.HTTPException(status_code=404,
                                        detail='No request found')

    if format != 'auto':
        use_html = format == 'html'
    else:
        # Guess from the User-Agent whether the caller is a browser.
        agent = request.headers.get('user-agent', '').lower()
        browser_markers = ('mozilla', 'chrome', 'safari', 'edge')
        use_html = any(marker in agent for marker in browser_markers)

    if use_html:
        template_path = pathlib.Path(__file__).parent / 'html' / 'log.html'
        with open(template_path, 'r', encoding='utf-8') as template_file:
            template = template_file.read()
        plain_url = request.url.include_query_params(format='plain')
        page = template.replace('{stream_url}', str(plain_url))
        return fastapi.responses.HTMLResponse(
            page,
            headers={
                'Cache-Control': 'no-cache, no-transform',
                'X-Accel-Buffering': 'no'
            })

    # Plain-text streaming: figure out which file to stream.
    if request_id is None:
        assert log_path is not None, (request_id, log_path)
        if log_path == constants.API_SERVER_LOGS:
            log_path_to_stream = pathlib.Path(
                constants.API_SERVER_LOGS).expanduser()
        else:
            # Any other log path must live under ~/sky_logs. Resolving and
            # comparing the common path rejects '..' traversal.
            logs_root = pathlib.Path(
                constants.SKY_LOGS_DIRECTORY).expanduser().resolve()
            candidate = logs_root.joinpath(log_path).resolve()
            if os.path.commonpath([candidate, logs_root]) != str(logs_root):
                raise fastapi.HTTPException(
                    status_code=400,
                    detail=f'Unauthorized log path: {log_path!r}')
            if not candidate.exists():
                raise fastapi.HTTPException(
                    status_code=404,
                    detail=f'Log path {log_path!r} does not exist')
            log_path_to_stream = candidate
    else:
        task = requests_lib.get_request(request_id)
        if task is None:
            print(f'No task with request ID {request_id}')
            raise fastapi.HTTPException(
                status_code=404, detail=f'Request {request_id!r} not found')
        log_path_to_stream = task.log_path

    streamer = stream_utils.log_streamer(request_id,
                                         log_path_to_stream,
                                         plain_logs=format == 'plain',
                                         tail=tail,
                                         follow=follow)
    return fastapi.responses.StreamingResponse(
        content=streamer,
        media_type='text/plain',
        headers={
            'Cache-Control': 'no-cache, no-transform',
            'X-Accel-Buffering': 'no',
            'Transfer-Encoding': 'chunked'
        },
    )
919
+
920
+
921
@app.post('/api/cancel')
async def api_cancel(request: fastapi.Request,
                     request_cancel_body: payloads.RequestCancelBody) -> None:
    """Queue ``requests_lib.kill_requests`` (SHORT) to cancel requests."""
    executor.schedule_request(
        func=requests_lib.kill_requests,
        schedule_type=requests_lib.ScheduleType.SHORT,
        request_body=request_cancel_body,
        request_id=request.state.request_id,
        request_name='api_cancel')
932
+
933
+
934
@app.get('/api/status')
async def api_status(
    request_ids: Optional[List[str]] = fastapi.Query(
        None, description='Request IDs to get status for.'),
    all_status: bool = fastapi.Query(
        False, description='Get finished requests as well.'),
) -> List[requests_lib.RequestPayload]:
    """List requests in human-readable encoded form.

    With ``request_ids``, return those requests (unknown IDs are skipped).
    Otherwise return all requests, or only PENDING/RUNNING ones unless
    ``all_status`` is set.
    """
    if request_ids is not None:
        found = (requests_lib.get_request(rid) for rid in request_ids)
        return [task.readable_encode() for task in found if task is not None]

    statuses = None if all_status else [
        requests_lib.RequestStatus.PENDING,
        requests_lib.RequestStatus.RUNNING,
    ]
    tasks = requests_lib.get_request_tasks(status=statuses)
    return [task.readable_encode() for task in tasks]
961
+
962
+
963
@app.get('/api/health')
async def health() -> Dict[str, str]:
    """Report the API server's health and version information.

    Returns:
        A dict with keys:
        - status: ``common.ApiServerStatus.HEALTHY.value``.
        - api_version: ``server_constants.API_VERSION``.
        - commit: the SkyPilot commit (``sky.__commit__``).
        - version: the SkyPilot version (``sky.__version__``).
    """
    return dict(
        status=common.ApiServerStatus.HEALTHY.value,
        api_version=server_constants.API_VERSION,
        commit=sky.__commit__,
        version=sky.__version__,
    )
980
+
981
+
982
@app.websocket('/kubernetes-pod-ssh-proxy')
async def kubernetes_pod_ssh_proxy(
    websocket: fastapi.WebSocket,
    cluster_name_body: payloads.ClusterNameBody = fastapi.Depends()
) -> None:
    """Proxies SSH to a Kubernetes cluster's pod over a websocket.

    Starts the port-forward command of the cluster's first command runner
    (forwarding remote port 22). It waits until that command prints the local
    port it listens on, then relays bytes between the websocket and that local
    port until either side closes.

    Raises:
        fastapi.HTTPException: 404 if no record is found for the cluster;
            400 if the cluster is not UP or is not a Kubernetes cluster.
    """
    await websocket.accept()
    cluster_name = cluster_name_body.cluster_name
    logger.info(f'WebSocket connection accepted for cluster: {cluster_name}')

    cluster_records = core.status(cluster_name, all_users=True)
    # NOTE(review): presumably empty for an unknown cluster. Guard explicitly
    # instead of failing with an opaque IndexError.
    if not cluster_records:
        raise fastapi.HTTPException(status_code=404,
                                    detail=f'Cluster {cluster_name} not found')
    cluster_record = cluster_records[0]
    if cluster_record['status'] != status_lib.ClusterStatus.UP:
        raise fastapi.HTTPException(
            status_code=400, detail=f'Cluster {cluster_name} is not running')

    handle = cluster_record['handle']
    assert handle is not None, 'Cluster handle is None'
    if not isinstance(handle.launched_resources.cloud, clouds.Kubernetes):
        raise fastapi.HTTPException(
            status_code=400,
            detail=f'Cluster {cluster_name} is not a Kubernetes cluster. '
            'Use ssh to connect to the cluster instead.')

    kubectl_cmd = handle.get_command_runners()[0].port_forward_command(
        port_forward=[(None, 22)])
    proc = await asyncio.create_subprocess_exec(
        *kubectl_cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT)
    logger.info(f'Started kubectl port-forward with command: {kubectl_cmd}')

    # Wait for port-forward to be ready and get the local port
    local_port = None
    assert proc.stdout is not None
    while True:
        stdout_line = await proc.stdout.readline()
        if not stdout_line:
            # EOF before the "Forwarding from" line: port-forward ended.
            await websocket.close()
            return
        decoded_line = stdout_line.decode()
        logger.info(f'kubectl port-forward stdout: {decoded_line}')
        if 'Forwarding from 127.0.0.1' in decoded_line:
            # e.g. "Forwarding from 127.0.0.1:12345 -> 22" -> 12345
            port_str = decoded_line.split(':')[-1]
            local_port = int(port_str.replace(' -> ', ':').split(':')[0])
            break

    logger.info(f'Starting port-forward to local port: {local_port}')
    try:
        # Connect to the local port
        reader, writer = await asyncio.open_connection('127.0.0.1', local_port)

        async def websocket_to_ssh():
            try:
                async for message in websocket.iter_bytes():
                    writer.write(message)
                    await writer.drain()
            except fastapi.WebSocketDisconnect:
                pass
            writer.close()

        async def ssh_to_websocket():
            try:
                while True:
                    data = await reader.read(1024)
                    if not data:
                        break
                    await websocket.send_bytes(data)
            except Exception:  # pylint: disable=broad-except
                pass
            await websocket.close()

        await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
    finally:
        try:
            proc.terminate()
        except ProcessLookupError:
            # The port-forward process already exited; nothing to stop. Not
            # suppressing this would replace any in-flight exception.
            pass
1058
+
1059
+
1060
+ # === Internal APIs ===
1061
@app.get('/api/completion/cluster_name')
async def complete_cluster_name(incomplete: str) -> List[str]:
    """List cluster names that start with ``incomplete``."""
    matches = global_user_state.get_cluster_names_start_with(incomplete)
    return matches
1064
+
1065
+
1066
@app.get('/api/completion/storage_name')
async def complete_storage_name(incomplete: str) -> List[str]:
    """List storage names that start with ``incomplete``."""
    matches = global_user_state.get_storage_names_start_with(incomplete)
    return matches
1069
+
1070
+
1071
if __name__ == '__main__':
    import uvicorn

    parser = argparse.ArgumentParser()
    parser.add_argument('--host', default='127.0.0.1')
    parser.add_argument('--port', default=46580, type=int)
    parser.add_argument('--deploy', action='store_true')
    # Parse the command line before touching any state, so that `--help` or
    # an invalid argument exits without wiping the requests DB and logs.
    cmd_args = parser.parse_args()
    requests_lib.reset_db_and_logs()
    # Show the privacy policy if it is not already shown. We place it here so
    # that it is shown only when the API server is started.
    usage_lib.maybe_show_privacy_policy()

    # In deploy mode, run one uvicorn worker per CPU.
    num_workers = os.cpu_count() if cmd_args.deploy else None

    sub_procs = []
    try:
        sub_procs = executor.start(cmd_args.deploy)
        logger.info('Starting SkyPilot API server')
        # We don't support reload for now, since it may cause leakage of request
        # workers or interrupt running requests.
        uvicorn.run('sky.server.server:app',
                    host=cmd_args.host,
                    port=cmd_args.port,
                    workers=num_workers)
    except Exception as exc:  # pylint: disable=broad-except
        logger.error(f'Failed to start SkyPilot API server: '
                     f'{common_utils.format_exception(exc, use_bracket=True)}')
        raise
    finally:
        logger.info('Shutting down SkyPilot API server...')
        # Stop the request-executor processes started by `executor.start`.
        for sub_proc in sub_procs:
            sub_proc.terminate()
            sub_proc.join()