skypilot-nightly 1.0.0.dev20251009__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.

Files changed (231)
  1. sky/__init__.py +6 -2
  2. sky/adaptors/aws.py +25 -7
  3. sky/adaptors/coreweave.py +278 -0
  4. sky/adaptors/kubernetes.py +64 -0
  5. sky/adaptors/shadeform.py +89 -0
  6. sky/admin_policy.py +20 -0
  7. sky/authentication.py +59 -149
  8. sky/backends/backend_utils.py +104 -63
  9. sky/backends/cloud_vm_ray_backend.py +84 -39
  10. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  11. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  12. sky/catalog/kubernetes_catalog.py +24 -28
  13. sky/catalog/runpod_catalog.py +5 -1
  14. sky/catalog/shadeform_catalog.py +165 -0
  15. sky/check.py +25 -13
  16. sky/client/cli/command.py +335 -86
  17. sky/client/cli/flags.py +4 -2
  18. sky/client/cli/table_utils.py +17 -9
  19. sky/client/sdk.py +59 -12
  20. sky/cloud_stores.py +73 -0
  21. sky/clouds/__init__.py +2 -0
  22. sky/clouds/aws.py +71 -16
  23. sky/clouds/azure.py +12 -5
  24. sky/clouds/cloud.py +19 -9
  25. sky/clouds/cudo.py +12 -5
  26. sky/clouds/do.py +4 -1
  27. sky/clouds/fluidstack.py +12 -5
  28. sky/clouds/gcp.py +12 -5
  29. sky/clouds/hyperbolic.py +12 -5
  30. sky/clouds/ibm.py +12 -5
  31. sky/clouds/kubernetes.py +62 -25
  32. sky/clouds/lambda_cloud.py +12 -5
  33. sky/clouds/nebius.py +12 -5
  34. sky/clouds/oci.py +12 -5
  35. sky/clouds/paperspace.py +4 -1
  36. sky/clouds/primeintellect.py +4 -1
  37. sky/clouds/runpod.py +12 -5
  38. sky/clouds/scp.py +12 -5
  39. sky/clouds/seeweb.py +4 -1
  40. sky/clouds/shadeform.py +400 -0
  41. sky/clouds/ssh.py +4 -2
  42. sky/clouds/vast.py +12 -5
  43. sky/clouds/vsphere.py +4 -1
  44. sky/core.py +12 -11
  45. sky/dashboard/out/404.html +1 -1
  46. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  47. sky/dashboard/out/_next/static/chunks/{1871-49141c317f3a9020.js → 1871-74503c8e80fd253b.js} +1 -1
  48. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  49. sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
  50. sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
  51. sky/dashboard/out/_next/static/chunks/{3785.a19328ba41517b8b.js → 3785.ad6adaa2a0fa9768.js} +1 -1
  52. sky/dashboard/out/_next/static/chunks/{4725.10f7a9a5d3ea8208.js → 4725.a830b5c9e7867c92.js} +1 -1
  53. sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
  54. sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
  55. sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
  56. sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
  57. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  58. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  59. sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
  60. sky/dashboard/out/_next/static/chunks/pages/{_app-ce361c6959bc2001.js → _app-bde01e4a2beec258.js} +1 -1
  61. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
  62. sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-477555ab7c0b13d8.js → [cluster]-a37d2063af475a1c.js} +1 -1
  63. sky/dashboard/out/_next/static/chunks/pages/{clusters-2f61f65487f6d8ff.js → clusters-d44859594e6f8064.js} +1 -1
  64. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-553b8b5cb65e100b.js → [context]-c0b5935149902e6f.js} +1 -1
  65. sky/dashboard/out/_next/static/chunks/pages/{infra-910a22500c50596f.js → infra-aed0ea19df7cf961.js} +1 -1
  66. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
  67. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/{[pool]-bc979970c247d8f3.js → [pool]-6edeb7d06032adfc.js} +2 -2
  68. sky/dashboard/out/_next/static/chunks/pages/{jobs-a35a9dc3c5ccd657.js → jobs-479dde13399cf270.js} +1 -1
  69. sky/dashboard/out/_next/static/chunks/pages/{users-98d2ed979084162a.js → users-5ab3b907622cf0fe.js} +1 -1
  70. sky/dashboard/out/_next/static/chunks/pages/{volumes-835d14ba94808f79.js → volumes-b84b948ff357c43e.js} +1 -1
  71. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-e8688c35c06f0ac5.js → [name]-c5a3eeee1c218af1.js} +1 -1
  72. sky/dashboard/out/_next/static/chunks/pages/{workspaces-69c80d677d3c2949.js → workspaces-22b23febb3e89ce1.js} +1 -1
  73. sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
  74. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  75. sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
  76. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  77. sky/dashboard/out/clusters/[cluster].html +1 -1
  78. sky/dashboard/out/clusters.html +1 -1
  79. sky/dashboard/out/config.html +1 -1
  80. sky/dashboard/out/index.html +1 -1
  81. sky/dashboard/out/infra/[context].html +1 -1
  82. sky/dashboard/out/infra.html +1 -1
  83. sky/dashboard/out/jobs/[job].html +1 -1
  84. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  85. sky/dashboard/out/jobs.html +1 -1
  86. sky/dashboard/out/users.html +1 -1
  87. sky/dashboard/out/volumes.html +1 -1
  88. sky/dashboard/out/workspace/new.html +1 -1
  89. sky/dashboard/out/workspaces/[name].html +1 -1
  90. sky/dashboard/out/workspaces.html +1 -1
  91. sky/data/data_utils.py +92 -1
  92. sky/data/mounting_utils.py +143 -19
  93. sky/data/storage.py +168 -11
  94. sky/exceptions.py +13 -1
  95. sky/execution.py +13 -0
  96. sky/global_user_state.py +189 -113
  97. sky/jobs/client/sdk.py +32 -10
  98. sky/jobs/client/sdk_async.py +9 -3
  99. sky/jobs/constants.py +3 -1
  100. sky/jobs/controller.py +164 -192
  101. sky/jobs/file_content_utils.py +80 -0
  102. sky/jobs/log_gc.py +201 -0
  103. sky/jobs/recovery_strategy.py +59 -82
  104. sky/jobs/scheduler.py +20 -9
  105. sky/jobs/server/core.py +105 -23
  106. sky/jobs/server/server.py +40 -28
  107. sky/jobs/server/utils.py +32 -11
  108. sky/jobs/state.py +588 -110
  109. sky/jobs/utils.py +442 -209
  110. sky/logs/agent.py +1 -1
  111. sky/metrics/utils.py +45 -6
  112. sky/optimizer.py +1 -1
  113. sky/provision/__init__.py +7 -0
  114. sky/provision/aws/instance.py +2 -1
  115. sky/provision/azure/instance.py +2 -1
  116. sky/provision/common.py +2 -0
  117. sky/provision/cudo/instance.py +2 -1
  118. sky/provision/do/instance.py +2 -1
  119. sky/provision/fluidstack/instance.py +4 -3
  120. sky/provision/gcp/instance.py +2 -1
  121. sky/provision/hyperbolic/instance.py +2 -1
  122. sky/provision/instance_setup.py +10 -2
  123. sky/provision/kubernetes/constants.py +0 -1
  124. sky/provision/kubernetes/instance.py +222 -89
  125. sky/provision/kubernetes/network.py +12 -8
  126. sky/provision/kubernetes/utils.py +114 -53
  127. sky/provision/kubernetes/volume.py +5 -4
  128. sky/provision/lambda_cloud/instance.py +2 -1
  129. sky/provision/nebius/instance.py +2 -1
  130. sky/provision/oci/instance.py +2 -1
  131. sky/provision/paperspace/instance.py +2 -1
  132. sky/provision/provisioner.py +11 -2
  133. sky/provision/runpod/instance.py +2 -1
  134. sky/provision/scp/instance.py +2 -1
  135. sky/provision/seeweb/instance.py +3 -3
  136. sky/provision/shadeform/__init__.py +11 -0
  137. sky/provision/shadeform/config.py +12 -0
  138. sky/provision/shadeform/instance.py +351 -0
  139. sky/provision/shadeform/shadeform_utils.py +83 -0
  140. sky/provision/vast/instance.py +2 -1
  141. sky/provision/vsphere/instance.py +2 -1
  142. sky/resources.py +1 -1
  143. sky/schemas/api/responses.py +9 -5
  144. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  145. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  146. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  147. sky/schemas/generated/jobsv1_pb2.py +52 -52
  148. sky/schemas/generated/jobsv1_pb2.pyi +4 -2
  149. sky/schemas/generated/managed_jobsv1_pb2.py +39 -35
  150. sky/schemas/generated/managed_jobsv1_pb2.pyi +21 -5
  151. sky/serve/client/impl.py +11 -3
  152. sky/serve/replica_managers.py +5 -2
  153. sky/serve/serve_utils.py +9 -2
  154. sky/serve/server/impl.py +7 -2
  155. sky/serve/server/server.py +18 -15
  156. sky/serve/service.py +2 -2
  157. sky/server/auth/oauth2_proxy.py +2 -5
  158. sky/server/common.py +31 -28
  159. sky/server/constants.py +5 -1
  160. sky/server/daemons.py +27 -19
  161. sky/server/requests/executor.py +138 -74
  162. sky/server/requests/payloads.py +9 -1
  163. sky/server/requests/preconditions.py +13 -10
  164. sky/server/requests/request_names.py +120 -0
  165. sky/server/requests/requests.py +485 -153
  166. sky/server/requests/serializers/decoders.py +26 -13
  167. sky/server/requests/serializers/encoders.py +56 -11
  168. sky/server/requests/threads.py +106 -0
  169. sky/server/rest.py +70 -18
  170. sky/server/server.py +283 -104
  171. sky/server/stream_utils.py +233 -59
  172. sky/server/uvicorn.py +18 -17
  173. sky/setup_files/alembic.ini +4 -0
  174. sky/setup_files/dependencies.py +32 -13
  175. sky/sky_logging.py +0 -2
  176. sky/skylet/constants.py +30 -7
  177. sky/skylet/events.py +7 -0
  178. sky/skylet/log_lib.py +8 -2
  179. sky/skylet/log_lib.pyi +1 -1
  180. sky/skylet/services.py +26 -13
  181. sky/skylet/subprocess_daemon.py +103 -29
  182. sky/skypilot_config.py +87 -75
  183. sky/ssh_node_pools/server.py +9 -8
  184. sky/task.py +67 -54
  185. sky/templates/kubernetes-ray.yml.j2 +8 -1
  186. sky/templates/nebius-ray.yml.j2 +1 -0
  187. sky/templates/shadeform-ray.yml.j2 +72 -0
  188. sky/templates/websocket_proxy.py +142 -12
  189. sky/users/permission.py +8 -1
  190. sky/utils/admin_policy_utils.py +16 -3
  191. sky/utils/asyncio_utils.py +78 -0
  192. sky/utils/auth_utils.py +153 -0
  193. sky/utils/cli_utils/status_utils.py +8 -2
  194. sky/utils/command_runner.py +11 -0
  195. sky/utils/common.py +3 -1
  196. sky/utils/common_utils.py +7 -4
  197. sky/utils/context.py +57 -51
  198. sky/utils/context_utils.py +30 -12
  199. sky/utils/controller_utils.py +35 -8
  200. sky/utils/db/db_utils.py +37 -10
  201. sky/utils/db/migration_utils.py +8 -4
  202. sky/utils/locks.py +24 -6
  203. sky/utils/resource_checker.py +4 -1
  204. sky/utils/resources_utils.py +53 -29
  205. sky/utils/schemas.py +23 -4
  206. sky/utils/subprocess_utils.py +17 -4
  207. sky/volumes/server/server.py +7 -6
  208. sky/workspaces/server.py +13 -12
  209. {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/METADATA +306 -55
  210. {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/RECORD +215 -195
  211. sky/dashboard/out/_next/static/chunks/1121-d0782b9251f0fcd3.js +0 -1
  212. sky/dashboard/out/_next/static/chunks/1141-3b40c39626f99c89.js +0 -11
  213. sky/dashboard/out/_next/static/chunks/2755.97300e1362fe7c98.js +0 -26
  214. sky/dashboard/out/_next/static/chunks/3015-8d748834fcc60b46.js +0 -1
  215. sky/dashboard/out/_next/static/chunks/3294.1fafbf42b3bcebff.js +0 -1
  216. sky/dashboard/out/_next/static/chunks/6135-4b4d5e824b7f9d3c.js +0 -1
  217. sky/dashboard/out/_next/static/chunks/6856-5fdc9b851a18acdb.js +0 -1
  218. sky/dashboard/out/_next/static/chunks/6990-f6818c84ed8f1c86.js +0 -1
  219. sky/dashboard/out/_next/static/chunks/8969-66237729cdf9749e.js +0 -1
  220. sky/dashboard/out/_next/static/chunks/9025.c12318fb6a1a9093.js +0 -6
  221. sky/dashboard/out/_next/static/chunks/9360.71e83b2ddc844ec2.js +0 -31
  222. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8f058b0346db2aff.js +0 -16
  223. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-4f7079dcab6ed653.js +0 -16
  224. sky/dashboard/out/_next/static/chunks/webpack-6a5ddd0184bfa22c.js +0 -1
  225. sky/dashboard/out/_next/static/css/4614e06482d7309e.css +0 -3
  226. sky/dashboard/out/_next/static/hIViZcQBkn0HE8SpaSsUU/_buildManifest.js +0 -1
  227. /sky/dashboard/out/_next/static/{hIViZcQBkn0HE8SpaSsUU → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
  228. {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +0 -0
  229. {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
  230. {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
  231. {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,6 @@ import contextlib
 import dataclasses
 import enum
 import functools
-import json
 import os
 import pathlib
 import shutil
@@ -14,12 +13,14 @@ import sqlite3
 import threading
 import time
 import traceback
-from typing import (Any, AsyncContextManager, Callable, Dict, Generator, List,
-                    NamedTuple, Optional, Tuple)
+from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
+                    Tuple)
+import uuid

 import anyio
 import colorama
 import filelock
+import orjson

 from sky import exceptions
 from sky import global_user_state
@@ -32,6 +33,7 @@ from sky.server import daemons
 from sky.server.requests import payloads
 from sky.server.requests.serializers import decoders
 from sky.server.requests.serializers import encoders
+from sky.utils import asyncio_utils
 from sky.utils import common_utils
 from sky.utils import ux_utils
 from sky.utils.db import db_utils
@@ -211,8 +213,8 @@ class Request:
             entrypoint=self.entrypoint.__name__,
             request_body=self.request_body.model_dump_json(),
             status=self.status.value,
-            return_value=json.dumps(None),
-            error=json.dumps(None),
+            return_value=orjson.dumps(None).decode('utf-8'),
+            error=orjson.dumps(None).decode('utf-8'),
             pid=None,
             created_at=self.created_at,
             schedule_type=self.schedule_type.value,
@@ -235,8 +237,8 @@ class Request:
             entrypoint=encoders.pickle_and_encode(self.entrypoint),
             request_body=encoders.pickle_and_encode(self.request_body),
             status=self.status.value,
-            return_value=json.dumps(self.return_value),
-            error=json.dumps(self.error),
+            return_value=orjson.dumps(self.return_value).decode('utf-8'),
+            error=orjson.dumps(self.error).decode('utf-8'),
             pid=self.pid,
             created_at=self.created_at,
             schedule_type=self.schedule_type.value,
@@ -268,8 +270,8 @@ class Request:
             entrypoint=decoders.decode_and_unpickle(payload.entrypoint),
             request_body=decoders.decode_and_unpickle(payload.request_body),
             status=RequestStatus(payload.status),
-            return_value=json.loads(payload.return_value),
-            error=json.loads(payload.error),
+            return_value=orjson.loads(payload.return_value),
+            error=orjson.loads(payload.error),
             pid=payload.pid,
             created_at=payload.created_at,
             schedule_type=ScheduleType(payload.schedule_type),
@@ -292,72 +294,104 @@ class Request:
             raise


-def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
-    """Kill all pending and running requests for a cluster.
+def get_new_request_id() -> str:
+    """Get a new request ID."""
+    return str(uuid.uuid4())

-    Args:
-        cluster_name: the name of the cluster.
-        exclude_request_names: exclude requests with these names. This is to
-            prevent killing the caller request.
-    """
-    request_ids = [
-        request_task.request_id
-        for request_task in get_request_tasks(req_filter=RequestTaskFilter(
-            cluster_names=[cluster_name],
-            status=[RequestStatus.PENDING, RequestStatus.RUNNING],
-            exclude_request_names=[exclude_request_name]))
-    ]
-    kill_requests(request_ids)

+def encode_requests(requests: List[Request]) -> List[payloads.RequestPayload]:
+    """Serialize the SkyPilot API request for display purposes.

-def kill_requests(request_ids: Optional[List[str]] = None,
-                  user_id: Optional[str] = None) -> List[str]:
-    """Kill a SkyPilot API request and set its status to cancelled.
+    This function should be called on the server side to serialize the
+    request body into human readable format, e.g., the entrypoint should
+    be a string, and the pid, error, or return value are not needed.

-    Args:
-        request_ids: The request IDs to kill. If None, all requests for the
-            user are killed.
-        user_id: The user ID to kill requests for. If None, all users are
-            killed.
+    The returned value will then be displayed on the client side in request
+    table.

-    Returns:
-        A list of request IDs that were cancelled.
-    """
-    if request_ids is None:
-        request_ids = [
-            request_task.request_id
-            for request_task in get_request_tasks(req_filter=RequestTaskFilter(
-                user_id=user_id,
-                status=[RequestStatus.RUNNING, RequestStatus.PENDING],
-                # Avoid cancelling the cancel request itself.
-                exclude_request_names=['sky.api_cancel']))
-        ]
-    cancelled_request_ids = []
-    for request_id in request_ids:
-        with update_request(request_id) as request_record:
-            if request_record is None:
-                logger.debug(f'No request ID {request_id}')
-                continue
-            # Skip internal requests. The internal requests are scheduled with
-            # request_id in range(len(INTERNAL_REQUEST_EVENTS)).
-            if request_record.request_id in set(
-                    event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
-                continue
-            if request_record.status > RequestStatus.RUNNING:
-                logger.debug(f'Request {request_id} already finished')
-                continue
-            if request_record.pid is not None:
-                logger.debug(f'Killing request process {request_record.pid}')
-                # Use SIGTERM instead of SIGKILL:
-                # - The executor can handle SIGTERM gracefully
-                # - After SIGTERM, the executor can reuse the request process
-                #   for other requests, avoiding the overhead of forking a new
-                #   process for each request.
-                os.kill(request_record.pid, signal.SIGTERM)
-            request_record.status = RequestStatus.CANCELLED
-            request_record.finished_at = time.time()
-            cancelled_request_ids.append(request_id)
-    return cancelled_request_ids
+    We do not use `encode` for display to avoid a large amount of data being
+    sent to the client side, especially for the request table could include
+    all the requests.
+    """
+    encoded_requests = []
+    all_users = global_user_state.get_all_users()
+    all_users_map = {user.id: user.name for user in all_users}
+    for request in requests:
+        if request.request_body is not None:
+            assert isinstance(request.request_body,
+                              payloads.RequestBody), (request.name,
+                                                      request.request_body)
+        user_name = all_users_map.get(request.user_id)
+        payload = payloads.RequestPayload(
+            request_id=request.request_id,
+            name=request.name,
+            entrypoint=request.entrypoint.__name__
+            if request.entrypoint is not None else '',
+            request_body=request.request_body.model_dump_json()
+            if request.request_body is not None else
+            orjson.dumps(None).decode('utf-8'),
+            status=request.status.value,
+            return_value=orjson.dumps(None).decode('utf-8'),
+            error=orjson.dumps(None).decode('utf-8'),
+            pid=None,
+            created_at=request.created_at,
+            schedule_type=request.schedule_type.value,
+            user_id=request.user_id,
+            user_name=user_name,
+            cluster_name=request.cluster_name,
+            status_msg=request.status_msg,
+            should_retry=request.should_retry,
+            finished_at=request.finished_at,
+        )
+        encoded_requests.append(payload)
+    return encoded_requests
+
+
+def _update_request_row_fields(
+        row: Tuple[Any, ...],
+        fields: Optional[List[str]] = None) -> Tuple[Any, ...]:
+    """Update the request row fields."""
+    if not fields:
+        return row
+
+    # Convert tuple to dictionary for easier manipulation
+    content = dict(zip(fields, row))
+
+    # Required fields in RequestPayload
+    if 'request_id' not in fields:
+        content['request_id'] = ''
+    if 'name' not in fields:
+        content['name'] = ''
+    if 'entrypoint' not in fields:
+        content['entrypoint'] = server_constants.EMPTY_PICKLED_VALUE
+    if 'request_body' not in fields:
+        content['request_body'] = server_constants.EMPTY_PICKLED_VALUE
+    if 'status' not in fields:
+        content['status'] = RequestStatus.PENDING.value
+    if 'created_at' not in fields:
+        content['created_at'] = 0
+    if 'user_id' not in fields:
+        content['user_id'] = ''
+    if 'return_value' not in fields:
+        content['return_value'] = orjson.dumps(None).decode('utf-8')
+    if 'error' not in fields:
+        content['error'] = orjson.dumps(None).decode('utf-8')
+    if 'schedule_type' not in fields:
+        content['schedule_type'] = ScheduleType.SHORT.value
+    # Optional fields in RequestPayload
+    if 'pid' not in fields:
+        content['pid'] = None
+    if 'cluster_name' not in fields:
+        content['cluster_name'] = None
+    if 'status_msg' not in fields:
+        content['status_msg'] = None
+    if 'should_retry' not in fields:
+        content['should_retry'] = False
+    if 'finished_at' not in fields:
+        content['finished_at'] = None
+
+    # Convert back to tuple in the same order as REQUEST_COLUMNS
+    return tuple(content[col] for col in REQUEST_COLUMNS)


 def create_table(cursor, conn):
@@ -402,6 +436,21 @@ def create_table(cursor, conn):
     db_utils.add_column_to_table(cursor, conn, REQUEST_TABLE, COL_FINISHED_AT,
                                  'REAL')

+    # Add an index on (status, name) to speed up queries
+    # that filter on these columns.
+    cursor.execute(f"""\
+        CREATE INDEX IF NOT EXISTS status_name_idx ON {REQUEST_TABLE} (status, name) WHERE status IN ('PENDING', 'RUNNING');
+        """)
+    # Add an index on cluster_name to speed up queries
+    # that filter on this column.
+    cursor.execute(f"""\
+        CREATE INDEX IF NOT EXISTS cluster_name_idx ON {REQUEST_TABLE} ({COL_CLUSTER_NAME}) WHERE status IN ('PENDING', 'RUNNING');
+        """)
+    # Add an index on created_at to speed up queries that sort on this column.
+    cursor.execute(f"""\
+        CREATE INDEX IF NOT EXISTS created_at_idx ON {REQUEST_TABLE} (created_at);
+        """)
+

 _DB = None
 _init_db_lock = threading.Lock()
@@ -460,6 +509,26 @@ def reset_db_and_logs():
                     f'{server_common.API_SERVER_CLIENT_DIR.expanduser()}')
     shutil.rmtree(server_common.API_SERVER_CLIENT_DIR.expanduser(),
                   ignore_errors=True)
+    with _init_db_lock:
+        _init_db_within_lock()
+    assert _DB is not None
+    with _DB.conn:
+        cursor = _DB.conn.cursor()
+        cursor.execute('SELECT sqlite_version()')
+        row = cursor.fetchone()
+        if row is None:
+            raise RuntimeError('Failed to get SQLite version')
+        version_str = row[0]
+        version_parts = version_str.split('.')
+        assert len(version_parts) >= 2, \
+            f'Invalid version string: {version_str}'
+        major, minor = int(version_parts[0]), int(version_parts[1])
+        # SQLite 3.35.0+ supports RETURNING statements.
+        # 3.35.0 was released in March 2021.
+        if not ((major > 3) or (major == 3 and minor >= 35)):
+            raise RuntimeError(
+                f'SQLite version {version_str} is not supported. '
+                'Please upgrade to SQLite 3.35.0 or later.')


 def request_lock_path(request_id: str) -> str:
@@ -468,6 +537,132 @@ def request_lock_path(request_id: str) -> str:
     return os.path.join(lock_path, f'.{request_id}.lock')


+def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
+    """Kill all pending and running requests for a cluster.
+
+    Args:
+        cluster_name: the name of the cluster.
+        exclude_request_names: exclude requests with these names. This is to
+            prevent killing the caller request.
+    """
+    request_ids = [
+        request_task.request_id
+        for request_task in get_request_tasks(req_filter=RequestTaskFilter(
+            status=[RequestStatus.PENDING, RequestStatus.RUNNING],
+            exclude_request_names=[exclude_request_name],
+            cluster_names=[cluster_name],
+            fields=['request_id']))
+    ]
+    _kill_requests(request_ids)
+
+
+def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
+                              user_id: Optional[str] = None) -> List[str]:
+    """Kill requests with a given request ID prefix."""
+    expanded_request_ids: Optional[List[str]] = None
+    if request_ids is not None:
+        expanded_request_ids = []
+        for request_id in request_ids:
+            request_tasks = get_requests_with_prefix(request_id,
+                                                     fields=['request_id'])
+            if request_tasks is None or len(request_tasks) == 0:
+                continue
+            if len(request_tasks) > 1:
+                raise ValueError(f'Multiple requests found for '
+                                 f'request ID prefix: {request_id}')
+            expanded_request_ids.append(request_tasks[0].request_id)
+    return _kill_requests(request_ids=expanded_request_ids, user_id=user_id)
+
+
+# needed for backward compatibility. Remove by v0.10.7 or v0.11.0
+kill_requests = kill_requests_with_prefix
+
+
+def _should_kill_request(request_id: str,
+                         request_record: Optional[Request]) -> bool:
+    if request_record is None:
+        logger.debug(f'No request ID {request_id}')
+        return False
+    # Skip internal requests. The internal requests are scheduled with
+    # request_id in range(len(INTERNAL_REQUEST_EVENTS)).
+    if request_record.request_id in set(
+            event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
+        return False
+    if request_record.status > RequestStatus.RUNNING:
+        logger.debug(f'Request {request_id} already finished')
+        return False
+    return True
+
+
+def _kill_requests(request_ids: Optional[List[str]] = None,
+                   user_id: Optional[str] = None) -> List[str]:
+    """Kill a SkyPilot API request and set its status to cancelled.
+
+    Args:
+        request_ids: The request IDs to kill. If None, all requests for the
+            user are killed.
+        user_id: The user ID to kill requests for. If None, all users are
+            killed.
+
+    Returns:
+        A list of request IDs that were cancelled.
+    """
+    if request_ids is None:
+        request_ids = [
+            request_task.request_id
+            for request_task in get_request_tasks(req_filter=RequestTaskFilter(
+                status=[RequestStatus.PENDING, RequestStatus.RUNNING],
+                # Avoid cancelling the cancel request itself.
+                exclude_request_names=['sky.api_cancel'],
+                user_id=user_id,
+                fields=['request_id']))
+        ]
+    cancelled_request_ids = []
+    for request_id in request_ids:
+        with update_request(request_id) as request_record:
+            if not _should_kill_request(request_id, request_record):
+                continue
+            if request_record.pid is not None:
+                logger.debug(f'Killing request process {request_record.pid}')
+                # Use SIGTERM instead of SIGKILL:
+                # - The executor can handle SIGTERM gracefully
+                # - After SIGTERM, the executor can reuse the request process
+                #   for other requests, avoiding the overhead of forking a new
+                #   process for each request.
+                os.kill(request_record.pid, signal.SIGTERM)
+            request_record.status = RequestStatus.CANCELLED
+            request_record.finished_at = time.time()
+            cancelled_request_ids.append(request_id)
+    return cancelled_request_ids
+
+
+@init_db_async
+@asyncio_utils.shield
+async def kill_request_async(request_id: str) -> bool:
+    """Kill a SkyPilot API request and set its status to cancelled.
+
+    Returns:
+        True if the request was killed, False otherwise.
+    """
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request = await _get_request_no_lock_async(request_id)
+        if not _should_kill_request(request_id, request):
+            return False
+        assert request is not None
+        if request.pid is not None:
+            logger.debug(f'Killing request process {request.pid}')
+            # Use SIGTERM instead of SIGKILL:
+            # - The executor can handle SIGTERM gracefully
+            # - After SIGTERM, the executor can reuse the request process
+            #   for other requests, avoiding the overhead of forking a new
+            #   process for each request.
+            os.kill(request.pid, signal.SIGTERM)
+        request.status = RequestStatus.CANCELLED
+        request.finished_at = time.time()
+        await _add_or_update_request_no_lock_async(request)
+        return True
+
+
 @contextlib.contextmanager
 @init_db
 @metrics_lib.time_me
@@ -482,85 +677,144 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
                 _add_or_update_request_no_lock(request)


-@init_db
+@init_db_async
 @metrics_lib.time_me
-def update_request_async(
-        request_id: str) -> AsyncContextManager[Optional[Request]]:
-    """Async version of update_request.
-
-    Returns an async context manager that yields the request record and
-    persists any in-place updates upon exit.
-    """
-
-    @contextlib.asynccontextmanager
-    async def _cm():
-        # Acquire the lock to avoid race conditions between multiple request
-        # operations, e.g. execute and cancel.
-        async with filelock.AsyncFileLock(request_lock_path(request_id)):
-            request = await _get_request_no_lock_async(request_id)
-            try:
-                yield request
-            finally:
-                if request is not None:
-                    await _add_or_update_request_no_lock_async(request)
-
-    return _cm()
+@asyncio_utils.shield
+async def update_status_async(request_id: str, status: RequestStatus) -> None:
+    """Update the status of a request"""
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request = await _get_request_no_lock_async(request_id)
+        if request is not None:
+            request.status = status
+            await _add_or_update_request_no_lock_async(request)


-_get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} '
-                    'WHERE request_id LIKE ?')
+@init_db_async
+@metrics_lib.time_me
+@asyncio_utils.shield
+async def update_status_msg_async(request_id: str, status_msg: str) -> None:
+    """Update the status message of a request"""
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request = await _get_request_no_lock_async(request_id)
+        if request is not None:
+            request.status_msg = status_msg
+            await _add_or_update_request_no_lock_async(request)


-def _get_request_no_lock(request_id: str) -> Optional[Request]:
+def _get_request_no_lock(
+        request_id: str,
+        fields: Optional[List[str]] = None) -> Optional[Request]:
     """Get a SkyPilot API request."""
     assert _DB is not None
+    columns_str = ', '.join(REQUEST_COLUMNS)
+    if fields:
+        columns_str = ', '.join(fields)
     with _DB.conn:
         cursor = _DB.conn.cursor()
-        cursor.execute(_get_request_sql, (request_id + '%',))
+        cursor.execute((f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+                        'WHERE request_id LIKE ?'), (request_id + '%',))
         row = cursor.fetchone()
         if row is None:
             return None
+        if fields:
+            row = _update_request_row_fields(row, fields)
         return Request.from_row(row)


-async def _get_request_no_lock_async(request_id: str) -> Optional[Request]:
+async def _get_request_no_lock_async(
+        request_id: str,
+        fields: Optional[List[str]] = None) -> Optional[Request]:
     """Async version of _get_request_no_lock."""
     assert _DB is not None
-    async with _DB.execute_fetchall_async(_get_request_sql,
-                                          (request_id + '%',)) as rows:
+    columns_str = ', '.join(REQUEST_COLUMNS)
+    if fields:
+        columns_str = ', '.join(fields)
+    async with _DB.execute_fetchall_async(
+            (f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+             'WHERE request_id LIKE ?'), (request_id + '%',)) as rows:
         row = rows[0] if rows else None
         if row is None:
             return None
+        if fields:
+            row = _update_request_row_fields(row, fields)
         return Request.from_row(row)


-@init_db
+@init_db_async
 @metrics_lib.time_me
-def get_latest_request_id() -> Optional[str]:
+async def get_latest_request_id_async() -> Optional[str]:
     """Get the latest request ID."""
     assert _DB is not None
-    with _DB.conn:
-        cursor = _DB.conn.cursor()
-        cursor.execute(f'SELECT request_id FROM {REQUEST_TABLE} '
-                       'ORDER BY created_at DESC LIMIT 1')
-        row = cursor.fetchone()
-        return row[0] if row else None
+    async with _DB.execute_fetchall_async(
+            (f'SELECT request_id FROM {REQUEST_TABLE} '
+             'ORDER BY created_at DESC LIMIT 1')) as rows:
+        return rows[0][0] if rows else None


 @init_db
 @metrics_lib.time_me
-def get_request(request_id: str) -> Optional[Request]:
+def get_request(request_id: str,
+                fields: Optional[List[str]] = None) -> Optional[Request]:
     """Get a SkyPilot API request."""
     with filelock.FileLock(request_lock_path(request_id)):
-        return _get_request_no_lock(request_id)
+        return _get_request_no_lock(request_id, fields)


 @init_db_async
 @metrics_lib.time_me_async
-async def get_request_async(request_id: str) -> Optional[Request]:
+@asyncio_utils.shield
+async def get_request_async(
+        request_id: str,
+        fields: Optional[List[str]] = None) -> Optional[Request]:
     """Async version of get_request."""
+    # TODO(aylei): figure out how to remove FileLock here to avoid the overhead
     async with filelock.AsyncFileLock(request_lock_path(request_id)):
-        return await _get_request_no_lock_async(request_id)
+        return await _get_request_no_lock_async(request_id, fields)
+
+
+@init_db
+@metrics_lib.time_me
+def get_requests_with_prefix(
+        request_id_prefix: str,
+        fields: Optional[List[str]] = None) -> Optional[List[Request]]:
+    """Get requests with a given request ID prefix."""
+    assert _DB is not None
+    if fields:
+        columns_str = ', '.join(fields)
+    else:
+        columns_str = ', '.join(REQUEST_COLUMNS)
+    with _DB.conn:
+        cursor = _DB.conn.cursor()
+        cursor.execute((f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+                        'WHERE request_id LIKE ?'), (request_id_prefix + '%',))
+        rows = cursor.fetchall()
+        if not rows:
+            return None
+        if fields:
+            rows = [_update_request_row_fields(row, fields) for row in rows]
+        return [Request.from_row(row) for row in rows]
+
+
+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def get_requests_async_with_prefix(
+        request_id_prefix: str,
+        fields: Optional[List[str]] = None) -> Optional[List[Request]]:
+    """Async version of get_request_with_prefix."""
+    assert _DB is not None
+    if fields:
+        columns_str = ', '.join(fields)
+    else:
+        columns_str = ', '.join(REQUEST_COLUMNS)
+    async with _DB.execute_fetchall_async(
+            (f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+             'WHERE request_id LIKE ?'), (request_id_prefix + '%',)) as rows:
+        if not rows:
+            return None
+        if fields:
+            rows = [_update_request_row_fields(row, fields) for row in rows]
+        return [Request.from_row(row) for row in rows]


 class StatusWithMsg(NamedTuple):
@@ -597,26 +851,29 @@ async def get_request_status_async(
     return StatusWithMsg(status, status_msg)


-@init_db
-@metrics_lib.time_me
-def create_if_not_exists(request: Request) -> bool:
-    """Create a SkyPilot API request if it does not exist."""
-    with filelock.FileLock(request_lock_path(request.request_id)):
-        if _get_request_no_lock(request.request_id) is not None:
-            return False
-        _add_or_update_request_no_lock(request)
-        return True
-
-
 @init_db_async
 @metrics_lib.time_me_async
+@asyncio_utils.shield
 async def create_if_not_exists_async(request: Request) -> bool:
-    """Async version of create_if_not_exists."""
-    async with filelock.AsyncFileLock(request_lock_path(request.request_id)):
-        if await _get_request_no_lock_async(request.request_id) is not None:
-            return False
-        await _add_or_update_request_no_lock_async(request)
-        return True
+    """Create a request if it does not exist, otherwise do nothing.
+
+    Returns:
+        True if a new request is created, False if the request already exists.
+    """
+    assert _DB is not None
+    request_columns = ', '.join(REQUEST_COLUMNS)
+    values_str = ', '.join(['?'] * len(REQUEST_COLUMNS))
+    sql_statement = (
+        f'INSERT INTO {REQUEST_TABLE} '
+        f'({request_columns}) VALUES '
+        f'({values_str}) ON CONFLICT(request_id) DO NOTHING RETURNING ROWID')
+    request_row = request.to_row()
+    # Execute the SQL statement without getting the request lock.
+    # The request lock is used to prevent racing with cancellation codepath,
+    # but a request cannot be cancelled before it is created.
+    row = await _DB.execute_get_returning_value_async(sql_statement,
+                                                      request_row)
+    return True if row else False


 @dataclasses.dataclass
@@ -634,6 +891,7 @@ class RequestTaskFilter:
             Mutually exclusive with exclude_request_names.
         finished_before: if provided, only include requests finished before this
             timestamp.
+        limit: the number of requests to show. If None, show all requests.

     Raises:
         ValueError: If both exclude_request_names and include_request_names are
@@ -645,6 +903,9 @@ class RequestTaskFilter:
    exclude_request_names: Optional[List[str]] = None
    include_request_names: Optional[List[str]] = None
    finished_before: Optional[float] = None
+    limit: Optional[int] = None
+    fields: Optional[List[str]] = None
+    sort: bool = False

    def __post_init__(self):
        if (self.exclude_request_names is not None and
@@ -665,6 +926,10 @@ class RequestTaskFilter:
            status_list_str = ','.join(
                repr(status.value) for status in self.status)
            filters.append(f'status IN ({status_list_str})')
+        if self.include_request_names is not None:
+            request_names_str = ','.join(
+                repr(name) for name in self.include_request_names)
+            filters.append(f'name IN ({request_names_str})')
        if self.exclude_request_names is not None:
            exclude_request_names_str = ','.join(
                repr(name) for name in self.exclude_request_names)
@@ -676,10 +941,6 @@ class RequestTaskFilter:
        if self.user_id is not None:
            filters.append(f'{COL_USER_ID} = ?')
            filter_params.append(self.user_id)
-        if self.include_request_names is not None:
-            request_names_str = ','.join(
-                repr(name) for name in self.include_request_names)
-            filters.append(f'name IN ({request_names_str})')
        if self.finished_before is not None:
            filters.append('finished_at < ?')
            filter_params.append(self.finished_before)
@@ -687,8 +948,16 @@ class RequestTaskFilter:
        if filter_str:
            filter_str = f' WHERE {filter_str}'
        columns_str = ', '.join(REQUEST_COLUMNS)
-        return (f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str} '
-                'ORDER BY created_at DESC'), filter_params
+        if self.fields:
+            columns_str = ', '.join(self.fields)
+        sort_str = ''
+        if self.sort:
+            sort_str = ' ORDER BY created_at DESC'
+        query_str = (f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str}'
+                     f'{sort_str}')
+        if self.limit is not None:
+            query_str += f' LIMIT {self.limit}'
+        return query_str, filter_params


 @init_db
@@ -707,6 +976,10 @@ def get_request_tasks(req_filter: RequestTaskFilter) -> List[Request]:
        rows = cursor.fetchall()
        if rows is None:
            return []
+        if req_filter.fields:
+            rows = [
+                _update_request_row_fields(row, req_filter.fields) for row in rows
+            ]
        return [Request.from_row(row) for row in rows]


@@ -719,6 +992,10 @@ async def get_request_tasks_async(
    async with _DB.execute_fetchall_async(*req_filter.build_query()) as rows:
        if not rows:
            return []
+        if req_filter.fields:
+            rows = [
+                _update_request_row_fields(row, req_filter.fields) for row in rows
+            ]
        return [Request.from_row(row) for row in rows]


@@ -776,6 +1053,23 @@ def set_request_failed(request_id: str, e: BaseException) -> None:
        request_task.set_error(e)


+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def set_request_failed_async(request_id: str, e: BaseException) -> None:
+    """Set a request to failed and populate the error message."""
+    with ux_utils.enable_traceback():
+        stacktrace = traceback.format_exc()
+    setattr(e, 'stacktrace', stacktrace)
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request_task = await _get_request_no_lock_async(request_id)
+        assert request_task is not None, request_id
+        request_task.status = RequestStatus.FAILED
+        request_task.finished_at = time.time()
+        request_task.set_error(e)
+        await _add_or_update_request_no_lock_async(request_task)
+
+
 def set_request_succeeded(request_id: str, result: Optional[Any]) -> None:
     """Set a request to succeeded and populate the result."""
     with update_request(request_id) as request_task:
@@ -786,28 +1080,50 @@ def set_request_succeeded(request_id: str, result: Optional[Any]) -> None:
        request_task.set_return_value(result)


-def set_request_cancelled(request_id: str) -> None:
+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def set_request_succeeded_async(request_id: str,
+                                      result: Optional[Any]) -> None:
+    """Set a request to succeeded and populate the result."""
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request_task = await _get_request_no_lock_async(request_id)
+        assert request_task is not None, request_id
+        request_task.status = RequestStatus.SUCCEEDED
+        request_task.finished_at = time.time()
+        if result is not None:
+            request_task.set_return_value(result)
+        await _add_or_update_request_no_lock_async(request_task)
+
+
+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def set_request_cancelled_async(request_id: str) -> None:
     """Set a pending or running request to cancelled."""
-    with update_request(request_id) as request_task:
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request_task = await _get_request_no_lock_async(request_id)
        assert request_task is not None, request_id
        # Already finished or cancelled.
        if request_task.status > RequestStatus.RUNNING:
            return
        request_task.finished_at = time.time()
        request_task.status = RequestStatus.CANCELLED
+        await _add_or_update_request_no_lock_async(request_task)


 @init_db
 @metrics_lib.time_me
-async def _delete_requests(requests: List[Request]):
+async def _delete_requests(request_ids: List[str]):
     """Clean up requests by their IDs."""
-    id_list_str = ','.join(repr(req.request_id) for req in requests)
+    id_list_str = ','.join(repr(request_id) for request_id in request_ids)
     assert _DB is not None
     await _DB.execute_and_commit_async(
         f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')


-async def clean_finished_requests_with_retention(retention_seconds: int):
+async def clean_finished_requests_with_retention(retention_seconds: int,
+                                                 batch_size: int = 1000):
     """Clean up finished requests older than the retention period.

     This function removes old finished requests (SUCCEEDED, FAILED, CANCELLED)
@@ -816,24 +1132,40 @@ async def clean_finished_requests_with_retention(retention_seconds: int):
     Args:
         retention_seconds: Requests older than this many seconds will be
            deleted.
+        batch_size: batch delete 'batch_size' requests at a time to
+            avoid using too much memory and once and to let each
+            db query complete in a reasonable time. All stale
+            requests older than the retention period will be deleted
+            regardless of the batch size.
     """
-    reqs = await get_request_tasks_async(
-        req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
-                                     finished_before=time.time() -
-                                     retention_seconds))
-
-    futs = []
-    for req in reqs:
-        futs.append(
-            asyncio.create_task(
-                anyio.Path(req.log_path.absolute()).unlink(missing_ok=True)))
-    await asyncio.gather(*futs)
-
-    await _delete_requests(reqs)
+    total_deleted = 0
+    while True:
+        reqs = await get_request_tasks_async(
+            req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
+                                         finished_before=time.time() -
+                                         retention_seconds,
+                                         limit=batch_size,
+                                         fields=['request_id']))
+        if len(reqs) == 0:
+            break
+        futs = []
+        for req in reqs:
+            # req.log_path is derived from request_id,
+            # so it's ok to just grab the request_id in the above query.
+            futs.append(
+                asyncio.create_task(
+                    anyio.Path(
+                        req.log_path.absolute()).unlink(missing_ok=True)))
+        await asyncio.gather(*futs)
+
+        await _delete_requests([req.request_id for req in reqs])
+        total_deleted += len(reqs)
+        if len(reqs) < batch_size:
+            break

     # To avoid leakage of the log file, logs must be deleted before the
     # request task in the database.
-    logger.info(f'Cleaned up {len(reqs)} finished requests '
+    logger.info(f'Cleaned up {total_deleted} finished requests '
                 f'older than {retention_seconds} seconds')

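
A note on the recurring orjson.dumps(...).decode('utf-8') pattern in the diff above: unlike the standard-library json module, orjson.dumps returns bytes, so the value is decoded before being written to the text columns of the requests table, and orjson.loads accepts either str or bytes when rows are read back. A minimal illustrative sketch of that equivalence, assuming nothing beyond the orjson package being installed:

    import json

    import orjson

    # json.dumps returns str; orjson.dumps returns bytes.
    assert json.dumps(None) == 'null'
    assert orjson.dumps(None) == b'null'

    # Decoding the bytes yields the same text the json module would store,
    # so rows written by either serializer remain readable by orjson.loads.
    assert orjson.dumps(None).decode('utf-8') == json.dumps(None)
    assert orjson.loads('null') is None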