skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@
2
2
  import asyncio
3
3
  import logging
4
4
  import threading
5
- from typing import Dict, Union
5
+ from typing import Dict, Optional, Union
6
6
 
7
7
  import aiohttp
8
8
  import fastapi
@@ -27,20 +27,34 @@ class SkyServeLoadBalancer:
27
27
  policy.
28
28
  """
29
29
 
30
- def __init__(self, controller_url: str, load_balancer_port: int) -> None:
30
+ def __init__(
31
+ self,
32
+ controller_url: str,
33
+ load_balancer_port: int,
34
+ load_balancing_policy_name: Optional[str] = None,
35
+ tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
31
36
  """Initialize the load balancer.
32
37
 
33
38
  Args:
34
39
  controller_url: The URL of the controller.
35
40
  load_balancer_port: The port where the load balancer listens to.
41
+ load_balancing_policy_name: The name of the load balancing policy
42
+ to use. Defaults to None.
43
+ tls_credentials: The TLS credentials for HTTPS endpoint. Defaults
44
+ to None.
36
45
  """
37
46
  self._app = fastapi.FastAPI()
38
47
  self._controller_url: str = controller_url
39
48
  self._load_balancer_port: int = load_balancer_port
40
- self._load_balancing_policy: lb_policies.LoadBalancingPolicy = (
41
- lb_policies.RoundRobinPolicy())
49
+ # Use the registry to create the load balancing policy
50
+ self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make(
51
+ load_balancing_policy_name)
52
+ logger.info('Starting load balancer with policy '
53
+ f'{load_balancing_policy_name}.')
42
54
  self._request_aggregator: serve_utils.RequestsAggregator = (
43
55
  serve_utils.RequestTimestamp())
56
+ self._tls_credential: Optional[serve_utils.TLSCredential] = (
57
+ tls_credential)
44
58
  # TODO(tian): httpx.Client has a resource limit of 100 max connections
45
59
  # for each client. We should wait for feedback on the best max
46
60
  # connections.
@@ -79,7 +93,7 @@ class SkyServeLoadBalancer:
79
93
  'request_aggregator':
80
94
  self._request_aggregator.to_dict()
81
95
  },
82
- timeout=5,
96
+ timeout=aiohttp.ClientTimeout(5),
83
97
  ) as response:
84
98
  # Clean up after reporting request info to avoid OOM.
85
99
  self._request_aggregator.clear()
@@ -122,6 +136,7 @@ class SkyServeLoadBalancer:
122
136
  encountered if anything goes wrong.
123
137
  """
124
138
  logger.info(f'Proxy request to {url}')
139
+ self._load_balancing_policy.pre_execute_hook(url, request)
125
140
  try:
126
141
  # We defer the get of the client here on purpose, for case when the
127
142
  # replica is ready in `_proxy_with_retries` but refreshed before
@@ -141,11 +156,16 @@ class SkyServeLoadBalancer:
141
156
  content=await request.body(),
142
157
  timeout=constants.LB_STREAM_TIMEOUT)
143
158
  proxy_response = await client.send(proxy_request, stream=True)
159
+
160
+ async def background_func():
161
+ await proxy_response.aclose()
162
+ self._load_balancing_policy.post_execute_hook(url, request)
163
+
144
164
  return fastapi.responses.StreamingResponse(
145
165
  content=proxy_response.aiter_raw(),
146
166
  status_code=proxy_response.status_code,
147
167
  headers=proxy_response.headers,
148
- background=background.BackgroundTask(proxy_response.aclose))
168
+ background=background.BackgroundTask(background_func))
149
169
  except (httpx.RequestError, httpx.HTTPStatusError) as e:
150
170
  logger.error(f'Error when proxy request to {url}: '
151
171
  f'{common_utils.format_exception(e)}')
@@ -217,15 +237,38 @@ class SkyServeLoadBalancer:
217
237
  # Register controller synchronization task
218
238
  asyncio.create_task(self._sync_with_controller())
219
239
 
240
+ uvicorn_tls_kwargs = ({} if self._tls_credential is None else
241
+ self._tls_credential.dump_uvicorn_kwargs())
242
+
243
+ protocol = 'https' if self._tls_credential is not None else 'http'
244
+
220
245
  logger.info('SkyServe Load Balancer started on '
221
- f'http://0.0.0.0:{self._load_balancer_port}')
246
+ f'{protocol}://0.0.0.0:{self._load_balancer_port}')
247
+
248
+ uvicorn.run(self._app,
249
+ host='0.0.0.0',
250
+ port=self._load_balancer_port,
251
+ **uvicorn_tls_kwargs)
222
252
 
223
- uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port)
224
253
 
254
+ def run_load_balancer(
255
+ controller_addr: str,
256
+ load_balancer_port: int,
257
+ load_balancing_policy_name: Optional[str] = None,
258
+ tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
259
+ """ Run the load balancer.
225
260
 
226
- def run_load_balancer(controller_addr: str, load_balancer_port: int):
227
- load_balancer = SkyServeLoadBalancer(controller_url=controller_addr,
228
- load_balancer_port=load_balancer_port)
261
+ Args:
262
+ controller_addr: The address of the controller.
263
+ load_balancer_port: The port where the load balancer listens to.
264
+ policy_name: The name of the load balancing policy to use. Defaults to
265
+ None.
266
+ """
267
+ load_balancer = SkyServeLoadBalancer(
268
+ controller_url=controller_addr,
269
+ load_balancer_port=load_balancer_port,
270
+ load_balancing_policy_name=load_balancing_policy_name,
271
+ tls_credential=tls_credential)
229
272
  load_balancer.run()
230
273
 
231
274
 
@@ -241,5 +284,13 @@ if __name__ == '__main__':
241
284
  required=True,
242
285
  default=8890,
243
286
  help='The port where the load balancer listens to.')
287
+ available_policies = list(lb_policies.LB_POLICIES.keys())
288
+ parser.add_argument(
289
+ '--load-balancing-policy',
290
+ choices=available_policies,
291
+ default=lb_policies.DEFAULT_LB_POLICY,
292
+ help=f'The load balancing policy to use. Available policies: '
293
+ f'{", ".join(available_policies)}.')
244
294
  args = parser.parse_args()
245
- run_load_balancer(args.controller_addr, args.load_balancer_port)
295
+ run_load_balancer(args.controller_addr, args.load_balancer_port,
296
+ args.load_balancing_policy)
@@ -1,7 +1,9 @@
1
1
  """LoadBalancingPolicy: Policy to select endpoint."""
2
+ import collections
2
3
  import random
4
+ import threading
3
5
  import typing
4
- from typing import List, Optional
6
+ from typing import Dict, List, Optional
5
7
 
6
8
  from sky import sky_logging
7
9
 
@@ -10,6 +12,14 @@ if typing.TYPE_CHECKING:
10
12
 
11
13
  logger = sky_logging.init_logger(__name__)
12
14
 
15
+ # Define a registry for load balancing policies
16
+ LB_POLICIES = {}
17
+ DEFAULT_LB_POLICY = None
18
+ # Prior to #4439, the default policy was round_robin. We store the legacy
19
+ # default policy here to maintain backwards compatibility. Remove this after
20
+ # 2 minor release, i.e., 0.9.0.
21
+ LEGACY_DEFAULT_POLICY = 'round_robin'
22
+
13
23
 
14
24
  def _request_repr(request: 'fastapi.Request') -> str:
15
25
  return ('<Request '
@@ -25,6 +35,30 @@ class LoadBalancingPolicy:
25
35
  def __init__(self) -> None:
26
36
  self.ready_replicas: List[str] = []
27
37
 
38
+ def __init_subclass__(cls, name: str, default: bool = False):
39
+ LB_POLICIES[name] = cls
40
+ if default:
41
+ global DEFAULT_LB_POLICY
42
+ assert DEFAULT_LB_POLICY is None, (
43
+ 'Only one policy can be default.')
44
+ DEFAULT_LB_POLICY = name
45
+
46
+ @classmethod
47
+ def make_policy_name(cls, policy_name: Optional[str]) -> str:
48
+ """Return the policy name."""
49
+ assert DEFAULT_LB_POLICY is not None, 'No default policy set.'
50
+ if policy_name is None:
51
+ return DEFAULT_LB_POLICY
52
+ return policy_name
53
+
54
+ @classmethod
55
+ def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy':
56
+ """Create a load balancing policy from a name."""
57
+ policy_name = cls.make_policy_name(policy_name)
58
+ if policy_name not in LB_POLICIES:
59
+ raise ValueError(f'Unknown load balancing policy: {policy_name}')
60
+ return LB_POLICIES[policy_name]()
61
+
28
62
  def set_ready_replicas(self, ready_replicas: List[str]) -> None:
29
63
  raise NotImplementedError
30
64
 
@@ -43,8 +77,16 @@ class LoadBalancingPolicy:
43
77
  def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
44
78
  raise NotImplementedError
45
79
 
80
+ def pre_execute_hook(self, replica_url: str,
81
+ request: 'fastapi.Request') -> None:
82
+ pass
83
+
84
+ def post_execute_hook(self, replica_url: str,
85
+ request: 'fastapi.Request') -> None:
86
+ pass
46
87
 
47
- class RoundRobinPolicy(LoadBalancingPolicy):
88
+
89
+ class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin'):
48
90
  """Round-robin load balancing policy."""
49
91
 
50
92
  def __init__(self) -> None:
@@ -68,3 +110,43 @@ class RoundRobinPolicy(LoadBalancingPolicy):
68
110
  ready_replica_url = self.ready_replicas[self.index]
69
111
  self.index = (self.index + 1) % len(self.ready_replicas)
70
112
  return ready_replica_url
113
+
114
+
115
+ class LeastLoadPolicy(LoadBalancingPolicy, name='least_load', default=True):
116
+ """Least load load balancing policy."""
117
+
118
+ def __init__(self) -> None:
119
+ super().__init__()
120
+ self.load_map: Dict[str, int] = collections.defaultdict(int)
121
+ self.lock = threading.Lock()
122
+
123
+ def set_ready_replicas(self, ready_replicas: List[str]) -> None:
124
+ if set(self.ready_replicas) == set(ready_replicas):
125
+ return
126
+ with self.lock:
127
+ self.ready_replicas = ready_replicas
128
+ for r in self.ready_replicas:
129
+ if r not in ready_replicas:
130
+ del self.load_map[r]
131
+ for replica in ready_replicas:
132
+ self.load_map[replica] = self.load_map.get(replica, 0)
133
+
134
+ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
135
+ del request # Unused.
136
+ if not self.ready_replicas:
137
+ return None
138
+ with self.lock:
139
+ return min(self.ready_replicas,
140
+ key=lambda replica: self.load_map.get(replica, 0))
141
+
142
+ def pre_execute_hook(self, replica_url: str,
143
+ request: 'fastapi.Request') -> None:
144
+ del request # Unused.
145
+ with self.lock:
146
+ self.load_map[replica_url] += 1
147
+
148
+ def post_execute_hook(self, replica_url: str,
149
+ request: 'fastapi.Request') -> None:
150
+ del request # Unused.
151
+ with self.lock:
152
+ self.load_map[replica_url] -= 1
@@ -19,9 +19,9 @@ import sky
19
19
  from sky import backends
20
20
  from sky import core
21
21
  from sky import exceptions
22
+ from sky import execution
22
23
  from sky import global_user_state
23
24
  from sky import sky_logging
24
- from sky import status_lib
25
25
  from sky.backends import backend_utils
26
26
  from sky.serve import constants as serve_constants
27
27
  from sky.serve import serve_state
@@ -33,9 +33,11 @@ from sky.usage import usage_lib
33
33
  from sky.utils import common_utils
34
34
  from sky.utils import controller_utils
35
35
  from sky.utils import env_options
36
+ from sky.utils import status_lib
36
37
  from sky.utils import ux_utils
37
38
 
38
39
  if typing.TYPE_CHECKING:
40
+ from sky import resources
39
41
  from sky.serve import service_spec
40
42
 
41
43
  logger = sky_logging.init_logger(__name__)
@@ -94,12 +96,10 @@ def launch_cluster(replica_id: int,
94
96
  retry_cnt += 1
95
97
  try:
96
98
  usage_lib.messages.usage.set_internal()
97
- sky.launch(task,
98
- cluster_name,
99
- detach_setup=True,
100
- detach_run=True,
101
- retry_until_up=True,
102
- _is_launched_by_sky_serve_controller=True)
99
+ execution.launch(task,
100
+ cluster_name,
101
+ retry_until_up=True,
102
+ _is_launched_by_sky_serve_controller=True)
103
103
  logger.info(f'Replica cluster {cluster_name} launched.')
104
104
  except (exceptions.InvalidClusterNameError,
105
105
  exceptions.NoCloudAccessError,
@@ -147,7 +147,7 @@ def terminate_cluster(cluster_name: str,
147
147
  retry_cnt += 1
148
148
  try:
149
149
  usage_lib.messages.usage.set_internal()
150
- sky.down(cluster_name)
150
+ core.down(cluster_name)
151
151
  return
152
152
  except ValueError:
153
153
  # The cluster is already terminated.
@@ -170,12 +170,11 @@ def terminate_cluster(cluster_name: str,
170
170
  def _get_resources_ports(task_yaml: str) -> str:
171
171
  """Get the resources ports used by the task."""
172
172
  task = sky.Task.from_yaml(task_yaml)
173
- # Already checked all ports are the same in sky.serve.core.up
174
- assert len(task.resources) >= 1, task
175
- task_resources = list(task.resources)[0]
176
- # Already checked the resources have and only have one port
177
- # before upload the task yaml.
178
- return task_resources.ports[0]
173
+ # Already checked all ports are valid in sky.serve.core.up
174
+ assert task.resources, task
175
+ assert task.service is not None, task
176
+ assert task.service.ports is not None, task
177
+ return task.service.ports
179
178
 
180
179
 
181
180
  def _should_use_spot(task_yaml: str,
@@ -245,6 +244,8 @@ class ReplicaStatusProperty:
245
244
  is_scale_down: bool = False
246
245
  # The replica's spot instance was preempted.
247
246
  preempted: bool = False
247
+ # Whether the replica is purged.
248
+ purged: bool = False
248
249
 
249
250
  def remove_terminated_replica(self) -> bool:
250
251
  """Whether to remove the replica record from the replica table.
@@ -305,6 +306,8 @@ class ReplicaStatusProperty:
305
306
  return False
306
307
  if self.preempted:
307
308
  return False
309
+ if self.purged:
310
+ return False
308
311
  return True
309
312
 
310
313
  def to_replica_status(self) -> serve_state.ReplicaStatus:
@@ -488,6 +491,7 @@ class ReplicaInfo:
488
491
  self,
489
492
  readiness_path: str,
490
493
  post_data: Optional[Dict[str, Any]],
494
+ timeout: int,
491
495
  headers: Optional[Dict[str, str]],
492
496
  ) -> Tuple['ReplicaInfo', bool, float]:
493
497
  """Probe the readiness of the replica.
@@ -512,17 +516,15 @@ class ReplicaInfo:
512
516
  logger.info(f'Probing {replica_identity} with {readiness_path}.')
513
517
  if post_data is not None:
514
518
  msg += 'POST'
515
- response = requests.post(
516
- readiness_path,
517
- headers=headers,
518
- json=post_data,
519
- timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS)
519
+ response = requests.post(readiness_path,
520
+ json=post_data,
521
+ headers=headers,
522
+ timeout=timeout)
520
523
  else:
521
524
  msg += 'GET'
522
- response = requests.get(
523
- readiness_path,
524
- headers=headers,
525
- timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS)
525
+ response = requests.get(readiness_path,
526
+ headers=headers,
527
+ timeout=timeout)
526
528
  msg += (f' request to {replica_identity} returned status '
527
529
  f'code {response.status_code}')
528
530
  if response.status_code == 200:
@@ -580,8 +582,6 @@ class ReplicaManager:
580
582
  self.latest_version: int = serve_constants.INITIAL_VERSION
581
583
  # Oldest version among the currently provisioned and launched replicas
582
584
  self.least_recent_version: int = serve_constants.INITIAL_VERSION
583
- serve_state.add_or_update_version(self._service_name,
584
- self.latest_version, spec)
585
585
 
586
586
  def scale_up(self,
587
587
  resources_override: Optional[Dict[str, Any]] = None) -> None:
@@ -591,7 +591,7 @@ class ReplicaManager:
591
591
  """
592
592
  raise NotImplementedError
593
593
 
594
- def scale_down(self, replica_id: int) -> None:
594
+ def scale_down(self, replica_id: int, purge: bool = False) -> None:
595
595
  """Scale down replica with replica_id."""
596
596
  raise NotImplementedError
597
597
 
@@ -680,7 +680,8 @@ class SkyPilotReplicaManager(ReplicaManager):
680
680
  replica_id: int,
681
681
  sync_down_logs: bool,
682
682
  replica_drain_delay_seconds: int,
683
- is_scale_down: bool = False) -> None:
683
+ is_scale_down: bool = False,
684
+ purge: bool = False) -> None:
684
685
 
685
686
  if replica_id in self._launch_process_pool:
686
687
  info = serve_state.get_replica_info_from_id(self._service_name,
@@ -737,7 +738,8 @@ class SkyPilotReplicaManager(ReplicaManager):
737
738
  logger.info(f'\n== End of logs (Replica: {replica_id}) ==')
738
739
  with open(log_file_name, 'a',
739
740
  encoding='utf-8') as replica_log_file, open(
740
- job_log_file_name, 'r',
741
+ os.path.expanduser(job_log_file_name),
742
+ 'r',
741
743
  encoding='utf-8') as job_file:
742
744
  replica_log_file.write(job_file.read())
743
745
  else:
@@ -764,16 +766,18 @@ class SkyPilotReplicaManager(ReplicaManager):
764
766
  )
765
767
  info.status_property.sky_down_status = ProcessStatus.RUNNING
766
768
  info.status_property.is_scale_down = is_scale_down
769
+ info.status_property.purged = purge
767
770
  serve_state.add_or_update_replica(self._service_name, replica_id, info)
768
771
  p.start()
769
772
  self._down_process_pool[replica_id] = p
770
773
 
771
- def scale_down(self, replica_id: int) -> None:
774
+ def scale_down(self, replica_id: int, purge: bool = False) -> None:
772
775
  self._terminate_replica(
773
776
  replica_id,
774
777
  sync_down_logs=False,
775
778
  replica_drain_delay_seconds=_DEFAULT_DRAIN_SECONDS,
776
- is_scale_down=True)
779
+ is_scale_down=True,
780
+ purge=purge)
777
781
 
778
782
  def _handle_preemption(self, info: ReplicaInfo) -> bool:
779
783
  """Handle preemption of the replica if any error happened.
@@ -912,6 +916,8 @@ class SkyPilotReplicaManager(ReplicaManager):
912
916
  # since user should fixed the error before update.
913
917
  elif info.version != self.latest_version:
914
918
  removal_reason = 'for version outdated'
919
+ elif info.status_property.purged:
920
+ removal_reason = 'for purge'
915
921
  else:
916
922
  logger.info(f'Termination of replica {replica_id} '
917
923
  'finished. Replica info is kept since some '
@@ -972,7 +978,7 @@ class SkyPilotReplicaManager(ReplicaManager):
972
978
  if not info.status_property.should_track_service_status():
973
979
  continue
974
980
  # We use backend API to avoid usage collection in the
975
- # core.job_status.
981
+ # sdk.job_status.
976
982
  backend = backends.CloudVmRayBackend()
977
983
  handle = info.handle()
978
984
  assert handle is not None, info
@@ -990,9 +996,7 @@ class SkyPilotReplicaManager(ReplicaManager):
990
996
  # Re-raise the exception if it is not preempted.
991
997
  raise
992
998
  job_status = list(job_statuses.values())[0]
993
- if job_status in [
994
- job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
995
- ]:
999
+ if job_status in job_lib.JobStatus.user_code_failure_states():
996
1000
  info.status_property.user_app_failed = True
997
1001
  serve_state.add_or_update_replica(self._service_name,
998
1002
  info.replica_id, info)
@@ -1043,6 +1047,7 @@ class SkyPilotReplicaManager(ReplicaManager):
1043
1047
  (
1044
1048
  self._get_readiness_path(info.version),
1045
1049
  self._get_post_data(info.version),
1050
+ self._get_readiness_timeout_seconds(info.version),
1046
1051
  self._get_readiness_headers(info.version),
1047
1052
  ),
1048
1053
  ),)
@@ -1230,3 +1235,6 @@ class SkyPilotReplicaManager(ReplicaManager):
1230
1235
 
1231
1236
  def _get_initial_delay_seconds(self, version: int) -> int:
1232
1237
  return self._get_version_spec(version).initial_delay_seconds
1238
+
1239
+ def _get_readiness_timeout_seconds(self, version: int) -> int:
1240
+ return self._get_version_spec(version).readiness_timeout_seconds
sky/serve/serve_state.py CHANGED
@@ -11,22 +11,31 @@ from typing import Any, Dict, List, Optional, Tuple
11
11
  import colorama
12
12
 
13
13
  from sky.serve import constants
14
+ from sky.serve import load_balancing_policies as lb_policies
14
15
  from sky.utils import db_utils
15
16
 
16
17
  if typing.TYPE_CHECKING:
17
18
  from sky.serve import replica_managers
18
19
  from sky.serve import service_spec
19
20
 
20
- _DB_PATH = pathlib.Path(constants.SKYSERVE_METADATA_DIR) / 'services.db'
21
- _DB_PATH = _DB_PATH.expanduser().absolute()
22
- _DB_PATH.parents[0].mkdir(parents=True, exist_ok=True)
23
- _DB_PATH = str(_DB_PATH)
21
+
22
+ def _get_db_path() -> str:
23
+ """Workaround to collapse multi-step Path ops for type checker.
24
+ Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
25
+ """
26
+ path = pathlib.Path(constants.SKYSERVE_METADATA_DIR) / 'services.db'
27
+ path = path.expanduser().absolute()
28
+ path.parents[0].mkdir(parents=True, exist_ok=True)
29
+ return str(path)
30
+
31
+
32
+ _DB_PATH: str = _get_db_path()
24
33
 
25
34
 
26
35
  def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
27
36
  """Creates the service and replica tables if they do not exist."""
28
37
 
29
- # auto_restart column is deprecated.
38
+ # auto_restart and requested_resources columns are deprecated.
30
39
  cursor.execute("""\
31
40
  CREATE TABLE IF NOT EXISTS services (
32
41
  name TEXT PRIMARY KEY,
@@ -46,28 +55,35 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
46
55
  PRIMARY KEY (service_name, replica_id))""")
47
56
  cursor.execute("""\
48
57
  CREATE TABLE IF NOT EXISTS version_specs (
49
- version INTEGER,
58
+ version INTEGER,
50
59
  service_name TEXT,
51
60
  spec BLOB,
52
61
  PRIMARY KEY (service_name, version))""")
53
62
  conn.commit()
54
63
 
64
+ # Backward compatibility.
65
+ db_utils.add_column_to_table(cursor, conn, 'services',
66
+ 'requested_resources_str', 'TEXT')
67
+ # Deprecated: switched to `active_versions` below for the version
68
+ # considered active by the load balancer. The
69
+ # autoscaler/replica_manager version can be found in the
70
+ # version_specs table.
71
+ db_utils.add_column_to_table(
72
+ cursor, conn, 'services', 'current_version',
73
+ f'INTEGER DEFAULT {constants.INITIAL_VERSION}')
74
+ # The versions that are activated for the service. This is a list
75
+ # of integers in json format.
76
+ db_utils.add_column_to_table(cursor, conn, 'services', 'active_versions',
77
+ f'TEXT DEFAULT {json.dumps([])!r}')
78
+ db_utils.add_column_to_table(cursor, conn, 'services',
79
+ 'load_balancing_policy', 'TEXT DEFAULT NULL')
80
+ # Whether the service's load balancer is encrypted with TLS.
81
+ db_utils.add_column_to_table(cursor, conn, 'services', 'tls_encrypted',
82
+ 'INTEGER DEFAULT 0')
83
+ conn.commit()
84
+
55
85
 
56
- _DB = db_utils.SQLiteConn(_DB_PATH, create_table)
57
- # Backward compatibility.
58
- db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
59
- 'requested_resources_str', 'TEXT')
60
- # Deprecated: switched to `active_versions` below for the version considered
61
- # active by the load balancer. The authscaler/replica_manager version can be
62
- # found in the version_specs table.
63
- db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
64
- 'current_version',
65
- f'INTEGER DEFAULT {constants.INITIAL_VERSION}')
66
- # The versions that is activated for the service. This is a list of integers in
67
- # json format.
68
- db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
69
- 'active_versions',
70
- f'TEXT DEFAULT {json.dumps([])!r}')
86
+ db_utils.SQLiteConn(_DB_PATH, create_table)
71
87
  _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'
72
88
 
73
89
 
@@ -215,7 +231,7 @@ class ServiceStatus(enum.Enum):
215
231
  for status in ReplicaStatus.failed_statuses()) > 0:
216
232
  return cls.FAILED
217
233
  # When min_replicas = 0, there is no (provisioning) replica.
218
- if len(replica_statuses) == 0:
234
+ if not replica_statuses:
219
235
  return cls.NO_REPLICA
220
236
  return cls.REPLICA_INIT
221
237
 
@@ -233,7 +249,8 @@ _SERVICE_STATUS_TO_COLOR = {
233
249
 
234
250
 
235
251
  def add_service(name: str, controller_job_id: int, policy: str,
236
- requested_resources_str: str, status: ServiceStatus) -> bool:
252
+ requested_resources_str: str, load_balancing_policy: str,
253
+ status: ServiceStatus, tls_encrypted: bool) -> bool:
237
254
  """Add a service in the database.
238
255
 
239
256
  Returns:
@@ -246,10 +263,11 @@ def add_service(name: str, controller_job_id: int, policy: str,
246
263
  """\
247
264
  INSERT INTO services
248
265
  (name, controller_job_id, status, policy,
249
- requested_resources_str)
250
- VALUES (?, ?, ?, ?, ?)""",
266
+ requested_resources_str, load_balancing_policy, tls_encrypted)
267
+ VALUES (?, ?, ?, ?, ?, ?, ?)""",
251
268
  (name, controller_job_id, status.value, policy,
252
- requested_resources_str))
269
+ requested_resources_str, load_balancing_policy,
270
+ int(tls_encrypted)))
253
271
 
254
272
  except sqlite3.IntegrityError as e:
255
273
  if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
@@ -315,8 +333,13 @@ def set_service_load_balancer_port(service_name: str,
315
333
 
316
334
  def _get_service_from_row(row) -> Dict[str, Any]:
317
335
  (current_version, name, controller_job_id, controller_port,
318
- load_balancer_port, status, uptime, policy, _, requested_resources,
319
- requested_resources_str, _, active_versions) = row[:13]
336
+ load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
337
+ _, active_versions, load_balancing_policy, tls_encrypted) = row[:15]
338
+ if load_balancing_policy is None:
339
+ # This entry in database was added in #4439, and it will always be set
340
+ # to a str value. If it is None, it means it is a legacy entry and is
341
+ # using the legacy default policy.
342
+ load_balancing_policy = lb_policies.LEGACY_DEFAULT_POLICY
320
343
  return {
321
344
  'name': name,
322
345
  'controller_job_id': controller_job_id,
@@ -332,11 +355,9 @@ def _get_service_from_row(row) -> Dict[str, Any]:
332
355
  # The versions that are active for the load balancer. This is a list of
333
356
  # integers in json format. This is mainly for display purpose.
334
357
  'active_versions': json.loads(active_versions),
335
- # TODO(tian): Backward compatibility.
336
- # Remove after 2 minor release, 0.6.0.
337
- 'requested_resources': pickle.loads(requested_resources)
338
- if requested_resources is not None else None,
339
358
  'requested_resources_str': requested_resources_str,
359
+ 'load_balancing_policy': load_balancing_policy,
360
+ 'tls_encrypted': bool(tls_encrypted),
340
361
  }
341
362
 
342
363
 
@@ -525,3 +546,12 @@ def delete_version(service_name: str, version: int) -> None:
525
546
  DELETE FROM version_specs
526
547
  WHERE service_name=(?)
527
548
  AND version=(?)""", (service_name, version))
549
+
550
+
551
+ def delete_all_versions(service_name: str) -> None:
552
+ """Deletes all versions from the database."""
553
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
554
+ cursor.execute(
555
+ """\
556
+ DELETE FROM version_specs
557
+ WHERE service_name=(?)""", (service_name,))