skypilot-nightly 1.0.0.dev20250804__py3-none-any.whl → 1.0.0.dev20250807__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. See the package registry page for this release for more details.

Files changed (151)
  1. sky/__init__.py +2 -2
  2. sky/backends/cloud_vm_ray_backend.py +33 -4
  3. sky/catalog/kubernetes_catalog.py +8 -0
  4. sky/catalog/nebius_catalog.py +0 -1
  5. sky/check.py +11 -1
  6. sky/client/cli/command.py +234 -100
  7. sky/client/sdk.py +30 -9
  8. sky/client/sdk_async.py +815 -0
  9. sky/clouds/kubernetes.py +6 -1
  10. sky/clouds/nebius.py +1 -4
  11. sky/dashboard/out/404.html +1 -1
  12. sky/dashboard/out/_next/static/YAirOGsV1z6B2RJ0VIUmD/_buildManifest.js +1 -0
  13. sky/dashboard/out/_next/static/chunks/1141-a8a8f1adba34c892.js +11 -0
  14. sky/dashboard/out/_next/static/chunks/1871-980a395e92633a5c.js +6 -0
  15. sky/dashboard/out/_next/static/chunks/3785.6003d293cb83eab4.js +1 -0
  16. sky/dashboard/out/_next/static/chunks/{3698-7874720877646365.js → 3850-ff4a9a69d978632b.js} +1 -1
  17. sky/dashboard/out/_next/static/chunks/4725.29550342bd53afd8.js +1 -0
  18. sky/dashboard/out/_next/static/chunks/{4937.d6bf67771e353356.js → 4937.a2baa2df5572a276.js} +1 -1
  19. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  20. sky/dashboard/out/_next/static/chunks/6601-3e21152fe16da09c.js +1 -0
  21. sky/dashboard/out/_next/static/chunks/{691.6d99cbfba347cebf.js → 691.5eeedf82cc243343.js} +1 -1
  22. sky/dashboard/out/_next/static/chunks/6989-6129c1cfbcf51063.js +1 -0
  23. sky/dashboard/out/_next/static/chunks/6990-0f886f16e0d55ff8.js +1 -0
  24. sky/dashboard/out/_next/static/chunks/8056-019615038d6ce427.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/8252.62b0d23aed618bb2.js +16 -0
  26. sky/dashboard/out/_next/static/chunks/8969-318c3dca725e8e5d.js +1 -0
  27. sky/dashboard/out/_next/static/chunks/{9025.7937c16bc8623516.js → 9025.a1bef12d672bb66d.js} +1 -1
  28. sky/dashboard/out/_next/static/chunks/9159-11421c0f2909236f.js +1 -0
  29. sky/dashboard/out/_next/static/chunks/9360.85b0b1b4054574dd.js +31 -0
  30. sky/dashboard/out/_next/static/chunks/9666.cd4273f2a5c5802c.js +1 -0
  31. sky/dashboard/out/_next/static/chunks/{9847.4c46c5e229c78704.js → 9847.757720f3b40c0aa5.js} +1 -1
  32. sky/dashboard/out/_next/static/chunks/{9984.78ee6d2c6fa4b0e8.js → 9984.c5564679e467d245.js} +1 -1
  33. sky/dashboard/out/_next/static/chunks/pages/{_app-a67ae198457b9886.js → _app-1e6de35d15a8d432.js} +1 -1
  34. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6fd1d2d8441aa54b.js +11 -0
  35. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-155d477a6c3e04e2.js +1 -0
  36. sky/dashboard/out/_next/static/chunks/pages/clusters-b30460f683e6ba96.js +1 -0
  37. sky/dashboard/out/_next/static/chunks/pages/{config-8620d099cbef8608.js → config-dfb9bf07b13045f4.js} +1 -1
  38. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-13d53fffc03ccb52.js +1 -0
  39. sky/dashboard/out/_next/static/chunks/pages/infra-fc9222e26c8e2f0d.js +1 -0
  40. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-154f55cf8af55be5.js +11 -0
  41. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-f5ccf5d39d87aebe.js +21 -0
  42. sky/dashboard/out/_next/static/chunks/pages/jobs-cdc60fb5d371e16a.js +1 -0
  43. sky/dashboard/out/_next/static/chunks/pages/users-7ed36e44e779d5c7.js +1 -0
  44. sky/dashboard/out/_next/static/chunks/pages/volumes-c9695d657f78b5dc.js +1 -0
  45. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  46. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-f72f73bcef9541dc.js +1 -0
  47. sky/dashboard/out/_next/static/chunks/pages/workspaces-8f67be60165724cc.js +1 -0
  48. sky/dashboard/out/_next/static/chunks/webpack-76efbdad99742559.js +1 -0
  49. sky/dashboard/out/_next/static/css/4614e06482d7309e.css +3 -0
  50. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  51. sky/dashboard/out/clusters/[cluster].html +1 -1
  52. sky/dashboard/out/clusters.html +1 -1
  53. sky/dashboard/out/config.html +1 -1
  54. sky/dashboard/out/index.html +1 -1
  55. sky/dashboard/out/infra/[context].html +1 -1
  56. sky/dashboard/out/infra.html +1 -1
  57. sky/dashboard/out/jobs/[job].html +1 -1
  58. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  59. sky/dashboard/out/jobs.html +1 -1
  60. sky/dashboard/out/users.html +1 -1
  61. sky/dashboard/out/volumes.html +1 -1
  62. sky/dashboard/out/workspace/new.html +1 -1
  63. sky/dashboard/out/workspaces/[name].html +1 -1
  64. sky/dashboard/out/workspaces.html +1 -1
  65. sky/global_user_state.py +14 -2
  66. sky/jobs/__init__.py +2 -0
  67. sky/jobs/client/sdk.py +43 -2
  68. sky/jobs/client/sdk_async.py +135 -0
  69. sky/jobs/server/core.py +48 -1
  70. sky/jobs/server/server.py +52 -3
  71. sky/jobs/state.py +5 -1
  72. sky/jobs/utils.py +3 -1
  73. sky/provision/kubernetes/utils.py +30 -4
  74. sky/provision/nebius/instance.py +1 -0
  75. sky/provision/nebius/utils.py +9 -1
  76. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  77. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  78. sky/serve/client/impl.py +85 -1
  79. sky/serve/client/sdk.py +16 -47
  80. sky/serve/client/sdk_async.py +130 -0
  81. sky/serve/constants.py +3 -1
  82. sky/serve/controller.py +6 -3
  83. sky/serve/load_balancer.py +3 -1
  84. sky/serve/serve_state.py +93 -5
  85. sky/serve/serve_utils.py +200 -67
  86. sky/serve/server/core.py +13 -197
  87. sky/serve/server/impl.py +261 -23
  88. sky/serve/service.py +15 -3
  89. sky/server/auth/__init__.py +0 -0
  90. sky/server/auth/authn.py +46 -0
  91. sky/server/auth/oauth2_proxy.py +185 -0
  92. sky/server/common.py +119 -21
  93. sky/server/constants.py +1 -1
  94. sky/server/daemons.py +60 -11
  95. sky/server/requests/executor.py +5 -3
  96. sky/server/requests/payloads.py +19 -0
  97. sky/server/rest.py +114 -0
  98. sky/server/server.py +44 -40
  99. sky/setup_files/dependencies.py +2 -0
  100. sky/skylet/constants.py +1 -1
  101. sky/skylet/events.py +5 -1
  102. sky/skylet/skylet.py +3 -1
  103. sky/task.py +61 -21
  104. sky/templates/kubernetes-ray.yml.j2 +9 -0
  105. sky/templates/nebius-ray.yml.j2 +1 -0
  106. sky/templates/sky-serve-controller.yaml.j2 +1 -0
  107. sky/usage/usage_lib.py +8 -6
  108. sky/utils/annotations.py +8 -3
  109. sky/utils/common_utils.py +11 -1
  110. sky/utils/controller_utils.py +7 -0
  111. sky/utils/db/migration_utils.py +2 -2
  112. sky/utils/rich_utils.py +120 -0
  113. {skypilot_nightly-1.0.0.dev20250804.dist-info → skypilot_nightly-1.0.0.dev20250807.dist-info}/METADATA +22 -13
  114. {skypilot_nightly-1.0.0.dev20250804.dist-info → skypilot_nightly-1.0.0.dev20250807.dist-info}/RECORD +120 -112
  115. sky/client/sdk.pyi +0 -300
  116. sky/dashboard/out/_next/static/KiGGm4fK0CpmN6BT17jkh/_buildManifest.js +0 -1
  117. sky/dashboard/out/_next/static/chunks/1043-928582d4860fef92.js +0 -1
  118. sky/dashboard/out/_next/static/chunks/1141-3f10a5a9f697c630.js +0 -11
  119. sky/dashboard/out/_next/static/chunks/1664-22b00e32c9ff96a4.js +0 -1
  120. sky/dashboard/out/_next/static/chunks/1871-7e17c195296e2ea9.js +0 -6
  121. sky/dashboard/out/_next/static/chunks/2003.f90b06bb1f914295.js +0 -1
  122. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
  123. sky/dashboard/out/_next/static/chunks/3785.95524bc443db8260.js +0 -1
  124. sky/dashboard/out/_next/static/chunks/4725.42f21f250f91f65b.js +0 -1
  125. sky/dashboard/out/_next/static/chunks/4869.18e6a4361a380763.js +0 -16
  126. sky/dashboard/out/_next/static/chunks/5230-f3bb2663e442e86c.js +0 -1
  127. sky/dashboard/out/_next/static/chunks/6601-234b1cf963c7280b.js +0 -1
  128. sky/dashboard/out/_next/static/chunks/6989-983d3ae7a874de98.js +0 -1
  129. sky/dashboard/out/_next/static/chunks/6990-08b2a1cae076a943.js +0 -1
  130. sky/dashboard/out/_next/static/chunks/8969-9a8cca241b30db83.js +0 -1
  131. sky/dashboard/out/_next/static/chunks/938-40d15b6261ec8dc1.js +0 -1
  132. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-fa63e8b1d203f298.js +0 -11
  133. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-9e7df5fc761c95a7.js +0 -1
  134. sky/dashboard/out/_next/static/chunks/pages/clusters-956ad430075efee8.js +0 -1
  135. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-9cfd875eecb6eaf5.js +0 -1
  136. sky/dashboard/out/_next/static/chunks/pages/infra-0fbdc9072f19fbe2.js +0 -1
  137. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-6c5af4c86e6ab3d3.js +0 -11
  138. sky/dashboard/out/_next/static/chunks/pages/jobs-6393a9edc7322b54.js +0 -1
  139. sky/dashboard/out/_next/static/chunks/pages/users-34d6bb10c3b3ee3d.js +0 -1
  140. sky/dashboard/out/_next/static/chunks/pages/volumes-225c8dae0634eb7f.js +0 -1
  141. sky/dashboard/out/_next/static/chunks/pages/workspace/new-92f741084a89e27b.js +0 -1
  142. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-4d41c9023287f59a.js +0 -1
  143. sky/dashboard/out/_next/static/chunks/pages/workspaces-e4cb7e97d37e93ad.js +0 -1
  144. sky/dashboard/out/_next/static/chunks/webpack-13145516b19858fb.js +0 -1
  145. sky/dashboard/out/_next/static/css/b3227360726f12eb.css +0 -3
  146. /sky/dashboard/out/_next/static/{KiGGm4fK0CpmN6BT17jkh → YAirOGsV1z6B2RJ0VIUmD}/_ssgManifest.js +0 -0
  147. /sky/dashboard/out/_next/static/chunks/{6135-d0e285ac5f3f2485.js → 6135-85426374db04811e.js} +0 -0
  148. {skypilot_nightly-1.0.0.dev20250804.dist-info → skypilot_nightly-1.0.0.dev20250807.dist-info}/WHEEL +0 -0
  149. {skypilot_nightly-1.0.0.dev20250804.dist-info → skypilot_nightly-1.0.0.dev20250807.dist-info}/entry_points.txt +0 -0
  150. {skypilot_nightly-1.0.0.dev20250804.dist-info → skypilot_nightly-1.0.0.dev20250807.dist-info}/licenses/LICENSE +0 -0
  151. {skypilot_nightly-1.0.0.dev20250804.dist-info → skypilot_nightly-1.0.0.dev20250807.dist-info}/top_level.txt +0 -0
sky/serve/service.py CHANGED
@@ -112,6 +112,10 @@ def cleanup_storage(task_yaml: str) -> bool:
112
112
 
113
113
  def _cleanup(service_name: str) -> bool:
114
114
  """Clean up all service related resources, i.e. replicas and storage."""
115
+ # Cleanup the HA recovery script first as it is possible that some error
116
+ # was raised when we construct the task object (e.g.,
117
+ # sky.exceptions.ResourcesUnavailableError).
118
+ serve_state.remove_ha_recovery_script(service_name)
115
119
  failed = False
116
120
  replica_infos = serve_state.get_replica_infos(service_name)
117
121
  info2proc: Dict[replica_managers.ReplicaInfo,
@@ -172,7 +176,7 @@ def _cleanup_task_run_script(job_id: int) -> None:
172
176
  logger.warning(f'Task run script {this_task_run_script} not found')
173
177
 
174
178
 
175
- def _start(service_name: str, tmp_task_yaml: str, job_id: int):
179
+ def _start(service_name: str, tmp_task_yaml: str, job_id: int, entrypoint: str):
176
180
  """Starts the service.
177
181
  This including the controller and load balancer.
178
182
  """
@@ -223,7 +227,9 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
223
227
  load_balancing_policy=service_spec.load_balancing_policy,
224
228
  status=serve_state.ServiceStatus.CONTROLLER_INIT,
225
229
  tls_encrypted=service_spec.tls_credential is not None,
226
- pool=service_spec.pool)
230
+ pool=service_spec.pool,
231
+ controller_pid=os.getpid(),
232
+ entrypoint=entrypoint)
227
233
  # Directly throw an error here. See sky/serve/api.py::up
228
234
  # for more details.
229
235
  if not success:
@@ -241,6 +247,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
241
247
  # sync to a tmp file first and then copy it to the final name
242
248
  # if there is no name conflict.
243
249
  shutil.copy(tmp_task_yaml, service_task_yaml)
250
+ else:
251
+ serve_state.update_service_controller_pid(service_name, os.getpid())
244
252
 
245
253
  controller_process = None
246
254
  load_balancer_process = None
@@ -358,8 +366,12 @@ if __name__ == '__main__':
358
366
  required=True,
359
367
  type=int,
360
368
  help='Job id for the service job.')
369
+ parser.add_argument('--entrypoint',
370
+ type=str,
371
+ help='Entrypoint to launch the service',
372
+ required=True)
361
373
  args = parser.parse_args()
362
374
  # We start process with 'spawn', because 'fork' could result in weird
363
375
  # behaviors; 'spawn' is also cross-platform.
364
376
  multiprocessing.set_start_method('spawn', force=True)
365
- _start(args.service_name, args.task_yaml, args.job_id)
377
+ _start(args.service_name, args.task_yaml, args.job_id, args.entrypoint)
sky/server/auth/__init__.py ADDED
File without changes
sky/server/auth/authn.py ADDED
@@ -0,0 +1,46 @@
1
+ """Authentication module."""
2
+ import json
3
+ from typing import Optional
4
+
5
+ import fastapi
6
+
7
+ from sky import models
8
+ from sky import sky_logging
9
+ from sky.skylet import constants
10
+
11
+ logger = sky_logging.init_logger(__name__)
12
+
13
+
14
+ # TODO(hailong): Remove this function and use request.state.auth_user instead.
15
+ async def override_user_info_in_request_body(request: fastapi.Request,
16
+ auth_user: Optional[models.User]):
17
+ if auth_user is None:
18
+ return
19
+
20
+ body = await request.body()
21
+ if body:
22
+ try:
23
+ original_json = await request.json()
24
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
25
+ logger.error(f'Error parsing request JSON: {e}')
26
+ else:
27
+ logger.debug(f'Overriding user for {request.state.request_id}: '
28
+ f'{auth_user.name}, {auth_user.id}')
29
+ if 'env_vars' in original_json:
30
+ if isinstance(original_json.get('env_vars'), dict):
31
+ original_json['env_vars'][
32
+ constants.USER_ID_ENV_VAR] = auth_user.id
33
+ original_json['env_vars'][
34
+ constants.USER_ENV_VAR] = auth_user.name
35
+ else:
36
+ logger.warning(
37
+ f'"env_vars" in request body is not a dictionary '
38
+ f'for request {request.state.request_id}. '
39
+ 'Skipping user info injection into body.')
40
+ else:
41
+ original_json['env_vars'] = {}
42
+ original_json['env_vars'][
43
+ constants.USER_ID_ENV_VAR] = auth_user.id
44
+ original_json['env_vars'][
45
+ constants.USER_ENV_VAR] = auth_user.name
46
+ request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
sky/server/auth/oauth2_proxy.py ADDED
@@ -0,0 +1,185 @@
1
+ """Authentication based on oauth2-proxy."""
2
+
3
+ import asyncio
4
+ import hashlib
5
+ import http
6
+ import os
7
+ from typing import Optional
8
+ import urllib
9
+
10
+ import aiohttp
11
+ import fastapi
12
+ import starlette.middleware.base
13
+
14
+ from sky import models
15
+ from sky import sky_logging
16
+ from sky.server.auth import authn
17
+ from sky.utils import common_utils
18
+
19
+ logger = sky_logging.init_logger(__name__)
20
+
21
+ # We do not support setting these in config.yaml because:
22
+ # 1. config.yaml can be updated dynamically, but auth middleware does not
23
+ # support hot reload yet.
24
+ # 2. If we introduce hot reload for auth middleware, bad config might
25
+ # invalidate all authenticated sessions and thus cannot be rolled back
26
+ # by API users.
27
+ # TODO(aylei): we should introduce server.yaml for static server admin config,
28
+ # which is more structured than multiple environment variables and can be less
29
+ # confusing to users.
30
+ OAUTH2_PROXY_BASE_URL_ENV_VAR = 'SKYPILOT_AUTH_OAUTH2_PROXY_BASE_URL'
31
+ OAUTH2_PROXY_ENABLED_ENV_VAR = 'SKYPILOT_AUTH_OAUTH2_PROXY_ENABLED'
32
+
33
+
34
+ class OAuth2ProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
35
+ """Middleware to handle authentication by delegating to OAuth2 Proxy."""
36
+
37
+ def __init__(self, application: fastapi.FastAPI):
38
+ super().__init__(application)
39
+ self.enabled: bool = (os.getenv(OAUTH2_PROXY_ENABLED_ENV_VAR,
40
+ 'false') == 'true')
41
+ self.proxy_base: str = ''
42
+ if self.enabled:
43
+ proxy_base = os.getenv(OAUTH2_PROXY_BASE_URL_ENV_VAR)
44
+ if not proxy_base:
45
+ raise ValueError('OAuth2 Proxy is enabled but base_url is not '
46
+ 'set')
47
+ self.proxy_base = proxy_base.rstrip('/')
48
+
49
+ async def dispatch(self, request: fastapi.Request, call_next):
50
+ if not self.enabled:
51
+ return await call_next(request)
52
+
53
+ # Forward /oauth2/* to oauth2-proxy, including /oauth2/start and
54
+ # /oauth2/callback.
55
+ if request.url.path.startswith('/oauth2'):
56
+ return await self.forward_to_oauth2_proxy(request)
57
+
58
+ return await self.authenticate(request, call_next)
59
+
60
+ async def forward_to_oauth2_proxy(self, request: fastapi.Request):
61
+ """Forward requests to oauth2-proxy service."""
62
+ logger.debug(f'forwarding to oauth2-proxy: {request.url.path}')
63
+ path = request.url.path.lstrip('/')
64
+ target_url = f'{self.proxy_base}/{path}'
65
+ body = await request.body()
66
+ async with aiohttp.ClientSession() as session:
67
+ try:
68
+ forwarded_headers = dict(request.headers)
69
+ async with session.request(
70
+ method=request.method,
71
+ url=target_url,
72
+ headers=forwarded_headers,
73
+ data=body,
74
+ cookies=request.cookies,
75
+ params=request.query_params,
76
+ allow_redirects=False,
77
+ ) as response:
78
+ response_body = await response.read()
79
+ fastapi_response = fastapi.responses.Response(
80
+ content=response_body,
81
+ status_code=response.status,
82
+ headers=dict(response.headers),
83
+ )
84
+ # Forward cookies from OAuth2 proxy response to client
85
+ for cookie_name, cookie in response.cookies.items():
86
+ fastapi_response.set_cookie(
87
+ key=cookie_name,
88
+ value=cookie.value,
89
+ max_age=cookie.get('max-age'),
90
+ expires=cookie.get('expires'),
91
+ path=cookie.get('path', '/'),
92
+ domain=cookie.get('domain'),
93
+ secure=cookie.get('secure', False),
94
+ httponly=cookie.get('httponly', False),
95
+ )
96
+ return fastapi_response
97
+ except (aiohttp.ClientError, asyncio.TimeoutError) as e:
98
+ logger.error(f'Error forwarding to OAuth2 proxy: {e}')
99
+ return fastapi.responses.JSONResponse(
100
+ status_code=http.HTTPStatus.BAD_GATEWAY,
101
+ content={'detail': 'oauth2-proxy service unavailable'})
102
+
103
+ async def authenticate(self, request: fastapi.Request, call_next):
104
+ if request.state.auth_user is not None:
105
+ # Already authenticated
106
+ return await call_next(request)
107
+
108
+ async with aiohttp.ClientSession() as session:
109
+ try:
110
+ return await self._authenticate(request, call_next, session)
111
+ except (aiohttp.ClientError, asyncio.TimeoutError) as e:
112
+ logger.error(f'Error communicating with OAuth2 proxy: {e}')
113
+ # Fail open or closed based on your security requirements
114
+ return fastapi.responses.JSONResponse(
115
+ status_code=http.HTTPStatus.BAD_GATEWAY,
116
+ content={'detail': 'oauth2-proxy service unavailable'})
117
+
118
+ async def _authenticate(self, request: fastapi.Request, call_next,
119
+ session: aiohttp.ClientSession):
120
+ forwarded_headers = dict(request.headers)
121
+ auth_url = f'{self.proxy_base}/oauth2/auth'
122
+ forwarded_headers['X-Forwarded-Uri'] = str(request.url).rstrip('/')
123
+ logger.debug(f'authenticate request: {request.url.path}')
124
+
125
+ async with session.request(
126
+ method=request.method,
127
+ url=auth_url,
128
+ headers=forwarded_headers,
129
+ cookies=request.cookies,
130
+ timeout=aiohttp.ClientTimeout(total=10),
131
+ allow_redirects=False,
132
+ ) as auth_response:
133
+
134
+ if auth_response.status == http.HTTPStatus.ACCEPTED:
135
+ # User is authenticated, extract user info from headers
136
+ auth_user = self.get_auth_user(auth_response)
137
+ if not auth_user:
138
+ return fastapi.responses.JSONResponse(
139
+ status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR,
140
+ content={
141
+ 'detail':
142
+ 'oauth2-proxy is enabled but did not'
143
+ 'return user info, check your oauth2-proxy'
144
+ 'setup.'
145
+ })
146
+ request.state.auth_user = auth_user
147
+ await authn.override_user_info_in_request_body(
148
+ request, auth_user)
149
+ return await call_next(request)
150
+ elif auth_response.status == http.HTTPStatus.UNAUTHORIZED:
151
+ # For /api/health, we should allow unauthenticated requests to
152
+ # not break healthz check.
153
+ # TODO(aylei): remove this to an aggregated login middleware
154
+ # in favor of the unified authentication.
155
+ if request.url.path.startswith('/api/health'):
156
+ request.state.anonymous_user = True
157
+ return await call_next(request)
158
+
159
+ # TODO(aylei): in unified authentication, the redirection
160
+ # or rejection should be done after all the authentication
161
+ # methods are performed.
162
+ # Not authenticated, redirect to sign-in
163
+ redirect_path = request.url.path
164
+ if request.url.query:
165
+ redirect_path += f'?{request.url.query}'
166
+ rd = urllib.parse.quote(redirect_path)
167
+ signin_url = (f'{request.base_url}oauth2/start?'
168
+ f'rd={rd}')
169
+ return fastapi.responses.RedirectResponse(url=signin_url)
170
+ else:
171
+ logger.error('oauth2-proxy returned unexpected status '
172
+ f'{auth_response.status}: {auth_response.text}')
173
+ return fastapi.responses.JSONResponse(
174
+ status_code=auth_response.status,
175
+ content={'detail': 'oauth2-proxy error'})
176
+
177
+ def get_auth_user(
178
+ self, response: aiohttp.ClientResponse) -> Optional[models.User]:
179
+ """Extract user info from OAuth2 proxy response headers."""
180
+ email_header = response.headers.get('X-Auth-Request-Email')
181
+ if email_header:
182
+ user_hash = hashlib.md5(email_header.encode()).hexdigest(
183
+ )[:common_utils.USER_HASH_LENGTH]
184
+ return models.User(id=user_hash, name=email_header)
185
+ return None
sky/server/common.py CHANGED
@@ -16,13 +16,15 @@ import tempfile
16
16
  import threading
17
17
  import time
18
18
  import typing
19
- from typing import Any, Dict, Literal, Optional, Tuple, Union
19
+ from typing import (Any, Callable, cast, Dict, Literal, Optional, Tuple,
20
+ TypeVar, Union)
20
21
  from urllib import parse
21
22
  import uuid
22
23
 
23
24
  import cachetools
24
25
  import colorama
25
26
  import filelock
27
+ from typing_extensions import ParamSpec
26
28
 
27
29
  from sky import exceptions
28
30
  from sky import sky_logging
@@ -41,12 +43,14 @@ from sky.utils import rich_utils
41
43
  from sky.utils import ux_utils
42
44
 
43
45
  if typing.TYPE_CHECKING:
46
+ import aiohttp
44
47
  import pydantic
45
48
  import requests
46
49
 
47
50
  from sky import dag as dag_lib
48
51
  from sky import models
49
52
  else:
53
+ aiohttp = adaptors_common.LazyImport('aiohttp')
50
54
  pydantic = adaptors_common.LazyImport('pydantic')
51
55
  requests = adaptors_common.LazyImport('requests')
52
56
 
@@ -92,6 +96,9 @@ logger = sky_logging.init_logger(__name__)
92
96
 
93
97
  hinted_for_server_install_version_mismatch = False
94
98
 
99
+ T = TypeVar('T')
100
+ P = ParamSpec('P')
101
+
95
102
 
96
103
  class ApiServerStatus(enum.Enum):
97
104
  HEALTHY = 'healthy'
@@ -175,24 +182,14 @@ def get_cookies_from_response(
175
182
  return cookies
176
183
 
177
184
 
178
- def make_authenticated_request(method: str,
179
- path: str,
180
- server_url: Optional[str] = None,
181
- retry: bool = True,
182
- **kwargs) -> 'requests.Response':
183
- """Make an authenticated HTTP request to the API server.
184
-
185
- Automatically handles service account token authentication or cookie-based
186
- authentication based on what's available.
187
-
188
- Args:
189
- method: HTTP method (GET, POST, etc.)
190
- path: API path (e.g., '/api/v1/status')
191
- server_url: Server URL, defaults to configured server
192
- **kwargs: Additional arguments to pass to requests
185
+ def _prepare_authenticated_request_params(
186
+ path: str,
187
+ server_url: Optional[str] = None,
188
+ **kwargs) -> Tuple[str, Dict[str, Any]]:
189
+ """Prepare common parameters for authenticated requests (sync or async).
193
190
 
194
191
  Returns:
195
- requests.Response object
192
+ Tuple of (url, updated_kwargs)
196
193
  """
197
194
  if server_url is None:
198
195
  server_url = get_server_url()
@@ -214,6 +211,41 @@ def make_authenticated_request(method: str,
214
211
  if not headers.get('Authorization') and 'cookies' not in kwargs:
215
212
  kwargs['cookies'] = get_api_cookie_jar()
216
213
 
214
+ return url, kwargs
215
+
216
+
217
+ def _convert_requests_cookies_to_aiohttp(
218
+ cookie_jar: requests.cookies.RequestsCookieJar) -> Dict[str, str]:
219
+ """Convert requests cookie jar to aiohttp-compatible dict format."""
220
+ cookies = {}
221
+ for cookie in cookie_jar:
222
+ cookies[cookie.name] = cookie.value
223
+ return cookies # type: ignore
224
+
225
+
226
+ def make_authenticated_request(method: str,
227
+ path: str,
228
+ server_url: Optional[str] = None,
229
+ retry: bool = True,
230
+ **kwargs) -> 'requests.Response':
231
+ """Make an authenticated HTTP request to the API server.
232
+
233
+ Automatically handles service account token authentication or cookie-based
234
+ authentication based on what's available.
235
+
236
+ Args:
237
+ method: HTTP method (GET, POST, etc.)
238
+ path: API path (e.g., '/api/v1/status')
239
+ server_url: Server URL, defaults to configured server
240
+ retry: Whether to retry on transient errors
241
+ **kwargs: Additional arguments to pass to requests
242
+
243
+ Returns:
244
+ requests.Response object
245
+ """
246
+ url, kwargs = _prepare_authenticated_request_params(path, server_url,
247
+ **kwargs)
248
+
217
249
  # Make the request
218
250
  if retry:
219
251
  return rest.request(method, url, **kwargs)
@@ -222,6 +254,69 @@ def make_authenticated_request(method: str,
222
254
  return rest.request_without_retry(method, url, **kwargs)
223
255
 
224
256
 
257
+ async def make_authenticated_request_async(
258
+ session: 'aiohttp.ClientSession',
259
+ method: str,
260
+ path: str,
261
+ server_url: Optional[str] = None,
262
+ retry: bool = True,
263
+ **kwargs) -> 'aiohttp.ClientResponse':
264
+ """Make an authenticated async HTTP request to the API server using aiohttp.
265
+
266
+ Automatically handles service account token authentication or cookie-based
267
+ authentication based on what's available.
268
+
269
+ Example usage:
270
+ async with aiohttp.ClientSession() as session:
271
+ response = await make_authenticated_request_async(
272
+ session, 'GET', '/api/v1/status')
273
+ data = await response.json()
274
+
275
+ Args:
276
+ session: aiohttp ClientSession to use for the request
277
+ method: HTTP method (GET, POST, etc.)
278
+ path: API path (e.g., '/api/v1/status')
279
+ server_url: Server URL, defaults to configured server
280
+ retry: Whether to retry on transient errors
281
+ **kwargs: Additional arguments to pass to aiohttp
282
+
283
+ Returns:
284
+ aiohttp.ClientResponse object
285
+
286
+ Raises:
287
+ aiohttp.ClientError: For HTTP-related errors
288
+ exceptions.ServerTemporarilyUnavailableError: When server returns 503
289
+ exceptions.RequestInterruptedError: When request is interrupted
290
+ """
291
+ url, kwargs = _prepare_authenticated_request_params(path, server_url,
292
+ **kwargs)
293
+
294
+ # Convert cookies to aiohttp format if needed
295
+ if 'cookies' in kwargs and isinstance(kwargs['cookies'],
296
+ requests.cookies.RequestsCookieJar):
297
+ kwargs['cookies'] = _convert_requests_cookies_to_aiohttp(
298
+ kwargs['cookies'])
299
+
300
+ # Convert params to strings for aiohttp compatibility
301
+ if 'params' in kwargs and kwargs['params'] is not None:
302
+ normalized_params = {}
303
+ for key, value in kwargs['params'].items():
304
+ if isinstance(value, bool):
305
+ normalized_params[key] = str(value).lower()
306
+ elif value is not None:
307
+ normalized_params[key] = str(value)
308
+ # Skip None values
309
+ kwargs['params'] = normalized_params
310
+
311
+ # Make the request
312
+ if retry:
313
+ return await rest.request_async(session, method, url, **kwargs)
314
+ else:
315
+ assert method == 'GET', 'Only GET requests can be done without retry'
316
+ return await rest.request_without_retry_async(session, method, url,
317
+ **kwargs)
318
+
319
+
225
320
  @annotations.lru_cache(scope='global')
226
321
  def get_server_url(host: Optional[str] = None) -> str:
227
322
  endpoint = DEFAULT_SERVER_URL
@@ -322,13 +417,14 @@ def get_api_server_status(endpoint: Optional[str] = None) -> ApiServerInfo:
322
417
  # The response is 200, so we can parse the response.
323
418
  try:
324
419
  result = response.json()
420
+ server_status = result.get('status')
325
421
  api_version = result.get('api_version')
326
422
  version = result.get('version')
327
423
  version_on_disk = result.get('version_on_disk')
328
424
  commit = result.get('commit')
329
425
  user = result.get('user')
330
426
  basic_auth_enabled = result.get('basic_auth_enabled')
331
- server_info = ApiServerInfo(status=ApiServerStatus.HEALTHY,
427
+ server_info = ApiServerInfo(status=ApiServerStatus(server_status),
332
428
  api_version=api_version,
333
429
  version=version,
334
430
  version_on_disk=version_on_disk,
@@ -662,14 +758,14 @@ def check_server_healthy_or_start_fn(deploy: bool = False,
662
758
  metrics_port, enable_basic_auth)
663
759
 
664
760
 
665
- def check_server_healthy_or_start(func):
761
+ def check_server_healthy_or_start(func: Callable[P, T]) -> Callable[P, T]:
666
762
 
667
763
  @functools.wraps(func)
668
764
  def wrapper(*args, deploy: bool = False, host: str = '127.0.0.1', **kwargs):
669
765
  check_server_healthy_or_start_fn(deploy, host)
670
766
  return func(*args, **kwargs)
671
767
 
672
- return wrapper
768
+ return cast(Callable[P, T], wrapper)
673
769
 
674
770
 
675
771
  def process_mounts_in_task_on_api_server(task: str, env_vars: Dict[str, str],
@@ -787,7 +883,8 @@ def request_body_to_params(body: 'pydantic.BaseModel') -> Dict[str, Any]:
787
883
 
788
884
  def reload_for_new_request(client_entrypoint: Optional[str],
789
885
  client_command: Optional[str],
790
- using_remote_api_server: bool, user: 'models.User'):
886
+ using_remote_api_server: bool, user: 'models.User',
887
+ request_id: str) -> None:
791
888
  """Reload modules, global variables, and usage message for a new request."""
792
889
  # This should be called first to make sure the logger is up-to-date.
793
890
  sky_logging.reload_logger()
@@ -801,6 +898,7 @@ def reload_for_new_request(client_entrypoint: Optional[str],
801
898
  client_command=client_command,
802
899
  using_remote_api_server=using_remote_api_server,
803
900
  user=user,
901
+ request_id=request_id,
804
902
  )
805
903
 
806
904
  # Clear cache should be called before reload_logger and usage reset,
sky/server/constants.py CHANGED
@@ -10,7 +10,7 @@ from sky.skylet import constants
10
10
  # based on version info is needed.
11
11
  # For more details and code guidelines, refer to:
12
12
  # https://docs.skypilot.co/en/latest/developers/CONTRIBUTING.html#backward-compatibility-guidelines
13
- API_VERSION = 13
13
+ API_VERSION = 15
14
14
 
15
15
  # The minimum peer API version that the code should still work with.
16
16
  # Notes (dev):
sky/server/daemons.py CHANGED
@@ -14,6 +14,10 @@ from sky.utils import ux_utils
14
14
  logger = sky_logging.init_logger(__name__)
15
15
 
16
16
 
17
+ def _default_should_skip():
18
+ return False
19
+
20
+
17
21
  @dataclasses.dataclass
18
22
  class InternalRequestDaemon:
19
23
  """Internal daemon that runs an event in the background."""
@@ -22,6 +26,7 @@ class InternalRequestDaemon:
22
26
  name: str
23
27
  event_fn: Callable[[], None]
24
28
  default_log_level: str = 'INFO'
29
+ should_skip: Callable[[], bool] = _default_should_skip
25
30
 
26
31
  def refresh_log_level(self) -> int:
27
32
  # pylint: disable=import-outside-toplevel
@@ -110,14 +115,14 @@ def managed_job_status_refresh_event():
110
115
  """Refresh the managed job status for controller consolidation mode."""
111
116
  # pylint: disable=import-outside-toplevel
112
117
  from sky.jobs import utils as managed_job_utils
113
- if not managed_job_utils.is_consolidation_mode():
114
- return
118
+ from sky.utils import controller_utils
119
+
115
120
  # We run the recovery logic before starting the event loop as those two are
116
121
  # conflicting. Check PERSISTENT_RUN_RESTARTING_SIGNAL_FILE for details.
117
- from sky.utils import controller_utils
118
122
  if controller_utils.high_availability_specified(
119
123
  controller_utils.Controllers.JOBS_CONTROLLER.value.cluster_name):
120
124
  managed_job_utils.ha_recovery_for_consolidation_mode()
125
+
121
126
  # After recovery, we start the event loop.
122
127
  from sky.skylet import events
123
128
  refresh_event = events.ManagedJobEvent()
@@ -128,20 +133,58 @@ def managed_job_status_refresh_event():
128
133
  time.sleep(events.EVENT_CHECKING_INTERVAL_SECONDS)
129
134
 
130
135
 
131
- def sky_serve_status_refresh_event():
136
+ def should_skip_managed_job_status_refresh():
137
+ """Check if the managed job status refresh event should be skipped."""
138
+ # pylint: disable=import-outside-toplevel
139
+ from sky.jobs import utils as managed_job_utils
140
+ return not managed_job_utils.is_consolidation_mode()
141
+
142
+
143
+ def _serve_status_refresh_event(pool: bool):
132
144
  """Refresh the sky serve status for controller consolidation mode."""
133
145
  # pylint: disable=import-outside-toplevel
134
146
  from sky.serve import serve_utils
135
- if not serve_utils.is_consolidation_mode():
136
- return
137
- # TODO(tian): Add HA recovery logic.
147
+ from sky.utils import controller_utils
148
+
149
+ # We run the recovery logic before starting the event loop as those two are
150
+ # conflicting. Check PERSISTENT_RUN_RESTARTING_SIGNAL_FILE for details.
151
+ controller = controller_utils.get_controller_for_pool(pool)
152
+ if controller_utils.high_availability_specified(
153
+ controller.value.cluster_name):
154
+ serve_utils.ha_recovery_for_consolidation_mode(pool=pool)
155
+
156
+ # After recovery, we start the event loop.
138
157
  from sky.skylet import events
139
- event = events.ServiceUpdateEvent()
140
- logger.info('=== Running serve status refresh event ===')
158
+ event = events.ServiceUpdateEvent(pool=pool)
159
+ noun = 'pool' if pool else 'serve'
160
+ logger.info(f'=== Running {noun} status refresh event ===')
141
161
  event.run()
142
162
  time.sleep(events.EVENT_CHECKING_INTERVAL_SECONDS)
143
163
 
144
164
 
165
+ def _should_skip_serve_status_refresh_event(pool: bool):
166
+ """Check if the serve status refresh event should be skipped."""
167
+ # pylint: disable=import-outside-toplevel
168
+ from sky.serve import serve_utils
169
+ return not serve_utils.is_consolidation_mode(pool=pool)
170
+
171
+
172
+ def sky_serve_status_refresh_event():
173
+ _serve_status_refresh_event(pool=False)
174
+
175
+
176
+ def should_skip_sky_serve_status_refresh():
177
+ return _should_skip_serve_status_refresh_event(pool=False)
178
+
179
+
180
+ def pool_status_refresh_event():
181
+ _serve_status_refresh_event(pool=True)
182
+
183
+
184
+ def should_skip_pool_status_refresh():
185
+ return _should_skip_serve_status_refresh_event(pool=True)
186
+
187
+
145
188
  # Register the events to run in the background.
146
189
  INTERNAL_REQUEST_DAEMONS = [
147
190
  # This status refresh daemon can cause the autostopp'ed/autodown'ed cluster
@@ -157,8 +200,14 @@ INTERNAL_REQUEST_DAEMONS = [
157
200
  event_fn=refresh_volume_status_event),
158
201
  InternalRequestDaemon(id='managed-job-status-refresh-daemon',
159
202
  name='managed-job-status',
160
- event_fn=managed_job_status_refresh_event),
203
+ event_fn=managed_job_status_refresh_event,
204
+ should_skip=should_skip_managed_job_status_refresh),
161
205
  InternalRequestDaemon(id='sky-serve-status-refresh-daemon',
162
206
  name='sky-serve-status',
163
- event_fn=sky_serve_status_refresh_event),
207
+ event_fn=sky_serve_status_refresh_event,
208
+ should_skip=should_skip_sky_serve_status_refresh),
209
+ InternalRequestDaemon(id='pool-status-refresh-daemon',
210
+ name='pool-status',
211
+ event_fn=pool_status_refresh_event,
212
+ should_skip=should_skip_pool_status_refresh),
164
213
  ]
@@ -271,7 +271,8 @@ def _get_queue(schedule_type: api_requests.ScheduleType) -> RequestQueue:
271
271
 
272
272
  @contextlib.contextmanager
273
273
  def override_request_env_and_config(
274
- request_body: payloads.RequestBody) -> Generator[None, None, None]:
274
+ request_body: payloads.RequestBody,
275
+ request_id: str) -> Generator[None, None, None]:
275
276
  """Override the environment and SkyPilot config for a request."""
276
277
  original_env = os.environ.copy()
277
278
  os.environ.update(request_body.env_vars)
@@ -292,7 +293,8 @@ def override_request_env_and_config(
292
293
  client_entrypoint=request_body.entrypoint,
293
294
  client_command=request_body.entrypoint_command,
294
295
  using_remote_api_server=request_body.using_remote_api_server,
295
- user=user)
296
+ user=user,
297
+ request_id=request_id)
296
298
  try:
297
299
  logger.debug(
298
300
  f'override path: {request_body.override_skypilot_config_path}')
@@ -376,7 +378,7 @@ def _request_execution_wrapper(request_id: str,
376
378
  # config, as there can be some logs during override that needs to be
377
379
  # captured in the log file.
378
380
  try:
379
- with override_request_env_and_config(request_body), \
381
+ with override_request_env_and_config(request_body, request_id), \
380
382
  tempstore.tempdir():
381
383
  if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
382
384
  config = skypilot_config.to_dict()