skypilot-nightly 1.0.0.dev20251009__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of skypilot-nightly might be problematic. Click here for more details.
- sky/__init__.py +6 -2
- sky/adaptors/aws.py +25 -7
- sky/adaptors/coreweave.py +278 -0
- sky/adaptors/kubernetes.py +64 -0
- sky/adaptors/shadeform.py +89 -0
- sky/admin_policy.py +20 -0
- sky/authentication.py +59 -149
- sky/backends/backend_utils.py +104 -63
- sky/backends/cloud_vm_ray_backend.py +84 -39
- sky/catalog/data_fetchers/fetch_runpod.py +698 -0
- sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
- sky/catalog/kubernetes_catalog.py +24 -28
- sky/catalog/runpod_catalog.py +5 -1
- sky/catalog/shadeform_catalog.py +165 -0
- sky/check.py +25 -13
- sky/client/cli/command.py +335 -86
- sky/client/cli/flags.py +4 -2
- sky/client/cli/table_utils.py +17 -9
- sky/client/sdk.py +59 -12
- sky/cloud_stores.py +73 -0
- sky/clouds/__init__.py +2 -0
- sky/clouds/aws.py +71 -16
- sky/clouds/azure.py +12 -5
- sky/clouds/cloud.py +19 -9
- sky/clouds/cudo.py +12 -5
- sky/clouds/do.py +4 -1
- sky/clouds/fluidstack.py +12 -5
- sky/clouds/gcp.py +12 -5
- sky/clouds/hyperbolic.py +12 -5
- sky/clouds/ibm.py +12 -5
- sky/clouds/kubernetes.py +62 -25
- sky/clouds/lambda_cloud.py +12 -5
- sky/clouds/nebius.py +12 -5
- sky/clouds/oci.py +12 -5
- sky/clouds/paperspace.py +4 -1
- sky/clouds/primeintellect.py +4 -1
- sky/clouds/runpod.py +12 -5
- sky/clouds/scp.py +12 -5
- sky/clouds/seeweb.py +4 -1
- sky/clouds/shadeform.py +400 -0
- sky/clouds/ssh.py +4 -2
- sky/clouds/vast.py +12 -5
- sky/clouds/vsphere.py +4 -1
- sky/core.py +12 -11
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
- sky/dashboard/out/_next/static/chunks/{1871-49141c317f3a9020.js → 1871-74503c8e80fd253b.js} +1 -1
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
- sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
- sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
- sky/dashboard/out/_next/static/chunks/{3785.a19328ba41517b8b.js → 3785.ad6adaa2a0fa9768.js} +1 -1
- sky/dashboard/out/_next/static/chunks/{4725.10f7a9a5d3ea8208.js → 4725.a830b5c9e7867c92.js} +1 -1
- sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
- sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
- sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
- sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
- sky/dashboard/out/_next/static/chunks/pages/{_app-ce361c6959bc2001.js → _app-bde01e4a2beec258.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-477555ab7c0b13d8.js → [cluster]-a37d2063af475a1c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{clusters-2f61f65487f6d8ff.js → clusters-d44859594e6f8064.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-553b8b5cb65e100b.js → [context]-c0b5935149902e6f.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{infra-910a22500c50596f.js → infra-aed0ea19df7cf961.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/{[pool]-bc979970c247d8f3.js → [pool]-6edeb7d06032adfc.js} +2 -2
- sky/dashboard/out/_next/static/chunks/pages/{jobs-a35a9dc3c5ccd657.js → jobs-479dde13399cf270.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{users-98d2ed979084162a.js → users-5ab3b907622cf0fe.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{volumes-835d14ba94808f79.js → volumes-b84b948ff357c43e.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-e8688c35c06f0ac5.js → [name]-c5a3eeee1c218af1.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{workspaces-69c80d677d3c2949.js → workspaces-22b23febb3e89ce1.js} +1 -1
- sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
- sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/data_utils.py +92 -1
- sky/data/mounting_utils.py +143 -19
- sky/data/storage.py +168 -11
- sky/exceptions.py +13 -1
- sky/execution.py +13 -0
- sky/global_user_state.py +189 -113
- sky/jobs/client/sdk.py +32 -10
- sky/jobs/client/sdk_async.py +9 -3
- sky/jobs/constants.py +3 -1
- sky/jobs/controller.py +164 -192
- sky/jobs/file_content_utils.py +80 -0
- sky/jobs/log_gc.py +201 -0
- sky/jobs/recovery_strategy.py +59 -82
- sky/jobs/scheduler.py +20 -9
- sky/jobs/server/core.py +105 -23
- sky/jobs/server/server.py +40 -28
- sky/jobs/server/utils.py +32 -11
- sky/jobs/state.py +588 -110
- sky/jobs/utils.py +442 -209
- sky/logs/agent.py +1 -1
- sky/metrics/utils.py +45 -6
- sky/optimizer.py +1 -1
- sky/provision/__init__.py +7 -0
- sky/provision/aws/instance.py +2 -1
- sky/provision/azure/instance.py +2 -1
- sky/provision/common.py +2 -0
- sky/provision/cudo/instance.py +2 -1
- sky/provision/do/instance.py +2 -1
- sky/provision/fluidstack/instance.py +4 -3
- sky/provision/gcp/instance.py +2 -1
- sky/provision/hyperbolic/instance.py +2 -1
- sky/provision/instance_setup.py +10 -2
- sky/provision/kubernetes/constants.py +0 -1
- sky/provision/kubernetes/instance.py +222 -89
- sky/provision/kubernetes/network.py +12 -8
- sky/provision/kubernetes/utils.py +114 -53
- sky/provision/kubernetes/volume.py +5 -4
- sky/provision/lambda_cloud/instance.py +2 -1
- sky/provision/nebius/instance.py +2 -1
- sky/provision/oci/instance.py +2 -1
- sky/provision/paperspace/instance.py +2 -1
- sky/provision/provisioner.py +11 -2
- sky/provision/runpod/instance.py +2 -1
- sky/provision/scp/instance.py +2 -1
- sky/provision/seeweb/instance.py +3 -3
- sky/provision/shadeform/__init__.py +11 -0
- sky/provision/shadeform/config.py +12 -0
- sky/provision/shadeform/instance.py +351 -0
- sky/provision/shadeform/shadeform_utils.py +83 -0
- sky/provision/vast/instance.py +2 -1
- sky/provision/vsphere/instance.py +2 -1
- sky/resources.py +1 -1
- sky/schemas/api/responses.py +9 -5
- sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
- sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
- sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
- sky/schemas/generated/jobsv1_pb2.py +52 -52
- sky/schemas/generated/jobsv1_pb2.pyi +4 -2
- sky/schemas/generated/managed_jobsv1_pb2.py +39 -35
- sky/schemas/generated/managed_jobsv1_pb2.pyi +21 -5
- sky/serve/client/impl.py +11 -3
- sky/serve/replica_managers.py +5 -2
- sky/serve/serve_utils.py +9 -2
- sky/serve/server/impl.py +7 -2
- sky/serve/server/server.py +18 -15
- sky/serve/service.py +2 -2
- sky/server/auth/oauth2_proxy.py +2 -5
- sky/server/common.py +31 -28
- sky/server/constants.py +5 -1
- sky/server/daemons.py +27 -19
- sky/server/requests/executor.py +138 -74
- sky/server/requests/payloads.py +9 -1
- sky/server/requests/preconditions.py +13 -10
- sky/server/requests/request_names.py +120 -0
- sky/server/requests/requests.py +485 -153
- sky/server/requests/serializers/decoders.py +26 -13
- sky/server/requests/serializers/encoders.py +56 -11
- sky/server/requests/threads.py +106 -0
- sky/server/rest.py +70 -18
- sky/server/server.py +283 -104
- sky/server/stream_utils.py +233 -59
- sky/server/uvicorn.py +18 -17
- sky/setup_files/alembic.ini +4 -0
- sky/setup_files/dependencies.py +32 -13
- sky/sky_logging.py +0 -2
- sky/skylet/constants.py +30 -7
- sky/skylet/events.py +7 -0
- sky/skylet/log_lib.py +8 -2
- sky/skylet/log_lib.pyi +1 -1
- sky/skylet/services.py +26 -13
- sky/skylet/subprocess_daemon.py +103 -29
- sky/skypilot_config.py +87 -75
- sky/ssh_node_pools/server.py +9 -8
- sky/task.py +67 -54
- sky/templates/kubernetes-ray.yml.j2 +8 -1
- sky/templates/nebius-ray.yml.j2 +1 -0
- sky/templates/shadeform-ray.yml.j2 +72 -0
- sky/templates/websocket_proxy.py +142 -12
- sky/users/permission.py +8 -1
- sky/utils/admin_policy_utils.py +16 -3
- sky/utils/asyncio_utils.py +78 -0
- sky/utils/auth_utils.py +153 -0
- sky/utils/cli_utils/status_utils.py +8 -2
- sky/utils/command_runner.py +11 -0
- sky/utils/common.py +3 -1
- sky/utils/common_utils.py +7 -4
- sky/utils/context.py +57 -51
- sky/utils/context_utils.py +30 -12
- sky/utils/controller_utils.py +35 -8
- sky/utils/db/db_utils.py +37 -10
- sky/utils/db/migration_utils.py +8 -4
- sky/utils/locks.py +24 -6
- sky/utils/resource_checker.py +4 -1
- sky/utils/resources_utils.py +53 -29
- sky/utils/schemas.py +23 -4
- sky/utils/subprocess_utils.py +17 -4
- sky/volumes/server/server.py +7 -6
- sky/workspaces/server.py +13 -12
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/METADATA +306 -55
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/RECORD +215 -195
- sky/dashboard/out/_next/static/chunks/1121-d0782b9251f0fcd3.js +0 -1
- sky/dashboard/out/_next/static/chunks/1141-3b40c39626f99c89.js +0 -11
- sky/dashboard/out/_next/static/chunks/2755.97300e1362fe7c98.js +0 -26
- sky/dashboard/out/_next/static/chunks/3015-8d748834fcc60b46.js +0 -1
- sky/dashboard/out/_next/static/chunks/3294.1fafbf42b3bcebff.js +0 -1
- sky/dashboard/out/_next/static/chunks/6135-4b4d5e824b7f9d3c.js +0 -1
- sky/dashboard/out/_next/static/chunks/6856-5fdc9b851a18acdb.js +0 -1
- sky/dashboard/out/_next/static/chunks/6990-f6818c84ed8f1c86.js +0 -1
- sky/dashboard/out/_next/static/chunks/8969-66237729cdf9749e.js +0 -1
- sky/dashboard/out/_next/static/chunks/9025.c12318fb6a1a9093.js +0 -6
- sky/dashboard/out/_next/static/chunks/9360.71e83b2ddc844ec2.js +0 -31
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8f058b0346db2aff.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-4f7079dcab6ed653.js +0 -16
- sky/dashboard/out/_next/static/chunks/webpack-6a5ddd0184bfa22c.js +0 -1
- sky/dashboard/out/_next/static/css/4614e06482d7309e.css +0 -3
- sky/dashboard/out/_next/static/hIViZcQBkn0HE8SpaSsUU/_buildManifest.js +0 -1
- /sky/dashboard/out/_next/static/{hIViZcQBkn0HE8SpaSsUU → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
sky/server/requests/requests.py
CHANGED
|
@@ -5,7 +5,6 @@ import contextlib
|
|
|
5
5
|
import dataclasses
|
|
6
6
|
import enum
|
|
7
7
|
import functools
|
|
8
|
-
import json
|
|
9
8
|
import os
|
|
10
9
|
import pathlib
|
|
11
10
|
import shutil
|
|
@@ -14,12 +13,14 @@ import sqlite3
|
|
|
14
13
|
import threading
|
|
15
14
|
import time
|
|
16
15
|
import traceback
|
|
17
|
-
from typing import (Any,
|
|
18
|
-
|
|
16
|
+
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
|
|
17
|
+
Tuple)
|
|
18
|
+
import uuid
|
|
19
19
|
|
|
20
20
|
import anyio
|
|
21
21
|
import colorama
|
|
22
22
|
import filelock
|
|
23
|
+
import orjson
|
|
23
24
|
|
|
24
25
|
from sky import exceptions
|
|
25
26
|
from sky import global_user_state
|
|
@@ -32,6 +33,7 @@ from sky.server import daemons
|
|
|
32
33
|
from sky.server.requests import payloads
|
|
33
34
|
from sky.server.requests.serializers import decoders
|
|
34
35
|
from sky.server.requests.serializers import encoders
|
|
36
|
+
from sky.utils import asyncio_utils
|
|
35
37
|
from sky.utils import common_utils
|
|
36
38
|
from sky.utils import ux_utils
|
|
37
39
|
from sky.utils.db import db_utils
|
|
@@ -211,8 +213,8 @@ class Request:
|
|
|
211
213
|
entrypoint=self.entrypoint.__name__,
|
|
212
214
|
request_body=self.request_body.model_dump_json(),
|
|
213
215
|
status=self.status.value,
|
|
214
|
-
return_value=
|
|
215
|
-
error=
|
|
216
|
+
return_value=orjson.dumps(None).decode('utf-8'),
|
|
217
|
+
error=orjson.dumps(None).decode('utf-8'),
|
|
216
218
|
pid=None,
|
|
217
219
|
created_at=self.created_at,
|
|
218
220
|
schedule_type=self.schedule_type.value,
|
|
@@ -235,8 +237,8 @@ class Request:
|
|
|
235
237
|
entrypoint=encoders.pickle_and_encode(self.entrypoint),
|
|
236
238
|
request_body=encoders.pickle_and_encode(self.request_body),
|
|
237
239
|
status=self.status.value,
|
|
238
|
-
return_value=
|
|
239
|
-
error=
|
|
240
|
+
return_value=orjson.dumps(self.return_value).decode('utf-8'),
|
|
241
|
+
error=orjson.dumps(self.error).decode('utf-8'),
|
|
240
242
|
pid=self.pid,
|
|
241
243
|
created_at=self.created_at,
|
|
242
244
|
schedule_type=self.schedule_type.value,
|
|
@@ -268,8 +270,8 @@ class Request:
|
|
|
268
270
|
entrypoint=decoders.decode_and_unpickle(payload.entrypoint),
|
|
269
271
|
request_body=decoders.decode_and_unpickle(payload.request_body),
|
|
270
272
|
status=RequestStatus(payload.status),
|
|
271
|
-
return_value=
|
|
272
|
-
error=
|
|
273
|
+
return_value=orjson.loads(payload.return_value),
|
|
274
|
+
error=orjson.loads(payload.error),
|
|
273
275
|
pid=payload.pid,
|
|
274
276
|
created_at=payload.created_at,
|
|
275
277
|
schedule_type=ScheduleType(payload.schedule_type),
|
|
@@ -292,72 +294,104 @@ class Request:
|
|
|
292
294
|
raise
|
|
293
295
|
|
|
294
296
|
|
|
295
|
-
def
|
|
296
|
-
"""
|
|
297
|
+
def get_new_request_id() -> str:
|
|
298
|
+
"""Get a new request ID."""
|
|
299
|
+
return str(uuid.uuid4())
|
|
297
300
|
|
|
298
|
-
Args:
|
|
299
|
-
cluster_name: the name of the cluster.
|
|
300
|
-
exclude_request_names: exclude requests with these names. This is to
|
|
301
|
-
prevent killing the caller request.
|
|
302
|
-
"""
|
|
303
|
-
request_ids = [
|
|
304
|
-
request_task.request_id
|
|
305
|
-
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
|
|
306
|
-
cluster_names=[cluster_name],
|
|
307
|
-
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
|
|
308
|
-
exclude_request_names=[exclude_request_name]))
|
|
309
|
-
]
|
|
310
|
-
kill_requests(request_ids)
|
|
311
301
|
|
|
302
|
+
def encode_requests(requests: List[Request]) -> List[payloads.RequestPayload]:
|
|
303
|
+
"""Serialize the SkyPilot API request for display purposes.
|
|
312
304
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
305
|
+
This function should be called on the server side to serialize the
|
|
306
|
+
request body into human readable format, e.g., the entrypoint should
|
|
307
|
+
be a string, and the pid, error, or return value are not needed.
|
|
316
308
|
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
user are killed.
|
|
320
|
-
user_id: The user ID to kill requests for. If None, all users are
|
|
321
|
-
killed.
|
|
309
|
+
The returned value will then be displayed on the client side in request
|
|
310
|
+
table.
|
|
322
311
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
312
|
+
We do not use `encode` for display to avoid a large amount of data being
|
|
313
|
+
sent to the client side, especially for the request table could include
|
|
314
|
+
all the requests.
|
|
315
|
+
"""
|
|
316
|
+
encoded_requests = []
|
|
317
|
+
all_users = global_user_state.get_all_users()
|
|
318
|
+
all_users_map = {user.id: user.name for user in all_users}
|
|
319
|
+
for request in requests:
|
|
320
|
+
if request.request_body is not None:
|
|
321
|
+
assert isinstance(request.request_body,
|
|
322
|
+
payloads.RequestBody), (request.name,
|
|
323
|
+
request.request_body)
|
|
324
|
+
user_name = all_users_map.get(request.user_id)
|
|
325
|
+
payload = payloads.RequestPayload(
|
|
326
|
+
request_id=request.request_id,
|
|
327
|
+
name=request.name,
|
|
328
|
+
entrypoint=request.entrypoint.__name__
|
|
329
|
+
if request.entrypoint is not None else '',
|
|
330
|
+
request_body=request.request_body.model_dump_json()
|
|
331
|
+
if request.request_body is not None else
|
|
332
|
+
orjson.dumps(None).decode('utf-8'),
|
|
333
|
+
status=request.status.value,
|
|
334
|
+
return_value=orjson.dumps(None).decode('utf-8'),
|
|
335
|
+
error=orjson.dumps(None).decode('utf-8'),
|
|
336
|
+
pid=None,
|
|
337
|
+
created_at=request.created_at,
|
|
338
|
+
schedule_type=request.schedule_type.value,
|
|
339
|
+
user_id=request.user_id,
|
|
340
|
+
user_name=user_name,
|
|
341
|
+
cluster_name=request.cluster_name,
|
|
342
|
+
status_msg=request.status_msg,
|
|
343
|
+
should_retry=request.should_retry,
|
|
344
|
+
finished_at=request.finished_at,
|
|
345
|
+
)
|
|
346
|
+
encoded_requests.append(payload)
|
|
347
|
+
return encoded_requests
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _update_request_row_fields(
|
|
351
|
+
row: Tuple[Any, ...],
|
|
352
|
+
fields: Optional[List[str]] = None) -> Tuple[Any, ...]:
|
|
353
|
+
"""Update the request row fields."""
|
|
354
|
+
if not fields:
|
|
355
|
+
return row
|
|
356
|
+
|
|
357
|
+
# Convert tuple to dictionary for easier manipulation
|
|
358
|
+
content = dict(zip(fields, row))
|
|
359
|
+
|
|
360
|
+
# Required fields in RequestPayload
|
|
361
|
+
if 'request_id' not in fields:
|
|
362
|
+
content['request_id'] = ''
|
|
363
|
+
if 'name' not in fields:
|
|
364
|
+
content['name'] = ''
|
|
365
|
+
if 'entrypoint' not in fields:
|
|
366
|
+
content['entrypoint'] = server_constants.EMPTY_PICKLED_VALUE
|
|
367
|
+
if 'request_body' not in fields:
|
|
368
|
+
content['request_body'] = server_constants.EMPTY_PICKLED_VALUE
|
|
369
|
+
if 'status' not in fields:
|
|
370
|
+
content['status'] = RequestStatus.PENDING.value
|
|
371
|
+
if 'created_at' not in fields:
|
|
372
|
+
content['created_at'] = 0
|
|
373
|
+
if 'user_id' not in fields:
|
|
374
|
+
content['user_id'] = ''
|
|
375
|
+
if 'return_value' not in fields:
|
|
376
|
+
content['return_value'] = orjson.dumps(None).decode('utf-8')
|
|
377
|
+
if 'error' not in fields:
|
|
378
|
+
content['error'] = orjson.dumps(None).decode('utf-8')
|
|
379
|
+
if 'schedule_type' not in fields:
|
|
380
|
+
content['schedule_type'] = ScheduleType.SHORT.value
|
|
381
|
+
# Optional fields in RequestPayload
|
|
382
|
+
if 'pid' not in fields:
|
|
383
|
+
content['pid'] = None
|
|
384
|
+
if 'cluster_name' not in fields:
|
|
385
|
+
content['cluster_name'] = None
|
|
386
|
+
if 'status_msg' not in fields:
|
|
387
|
+
content['status_msg'] = None
|
|
388
|
+
if 'should_retry' not in fields:
|
|
389
|
+
content['should_retry'] = False
|
|
390
|
+
if 'finished_at' not in fields:
|
|
391
|
+
content['finished_at'] = None
|
|
392
|
+
|
|
393
|
+
# Convert back to tuple in the same order as REQUEST_COLUMNS
|
|
394
|
+
return tuple(content[col] for col in REQUEST_COLUMNS)
|
|
361
395
|
|
|
362
396
|
|
|
363
397
|
def create_table(cursor, conn):
|
|
@@ -402,6 +436,21 @@ def create_table(cursor, conn):
|
|
|
402
436
|
db_utils.add_column_to_table(cursor, conn, REQUEST_TABLE, COL_FINISHED_AT,
|
|
403
437
|
'REAL')
|
|
404
438
|
|
|
439
|
+
# Add an index on (status, name) to speed up queries
|
|
440
|
+
# that filter on these columns.
|
|
441
|
+
cursor.execute(f"""\
|
|
442
|
+
CREATE INDEX IF NOT EXISTS status_name_idx ON {REQUEST_TABLE} (status, name) WHERE status IN ('PENDING', 'RUNNING');
|
|
443
|
+
""")
|
|
444
|
+
# Add an index on cluster_name to speed up queries
|
|
445
|
+
# that filter on this column.
|
|
446
|
+
cursor.execute(f"""\
|
|
447
|
+
CREATE INDEX IF NOT EXISTS cluster_name_idx ON {REQUEST_TABLE} ({COL_CLUSTER_NAME}) WHERE status IN ('PENDING', 'RUNNING');
|
|
448
|
+
""")
|
|
449
|
+
# Add an index on created_at to speed up queries that sort on this column.
|
|
450
|
+
cursor.execute(f"""\
|
|
451
|
+
CREATE INDEX IF NOT EXISTS created_at_idx ON {REQUEST_TABLE} (created_at);
|
|
452
|
+
""")
|
|
453
|
+
|
|
405
454
|
|
|
406
455
|
_DB = None
|
|
407
456
|
_init_db_lock = threading.Lock()
|
|
@@ -460,6 +509,26 @@ def reset_db_and_logs():
|
|
|
460
509
|
f'{server_common.API_SERVER_CLIENT_DIR.expanduser()}')
|
|
461
510
|
shutil.rmtree(server_common.API_SERVER_CLIENT_DIR.expanduser(),
|
|
462
511
|
ignore_errors=True)
|
|
512
|
+
with _init_db_lock:
|
|
513
|
+
_init_db_within_lock()
|
|
514
|
+
assert _DB is not None
|
|
515
|
+
with _DB.conn:
|
|
516
|
+
cursor = _DB.conn.cursor()
|
|
517
|
+
cursor.execute('SELECT sqlite_version()')
|
|
518
|
+
row = cursor.fetchone()
|
|
519
|
+
if row is None:
|
|
520
|
+
raise RuntimeError('Failed to get SQLite version')
|
|
521
|
+
version_str = row[0]
|
|
522
|
+
version_parts = version_str.split('.')
|
|
523
|
+
assert len(version_parts) >= 2, \
|
|
524
|
+
f'Invalid version string: {version_str}'
|
|
525
|
+
major, minor = int(version_parts[0]), int(version_parts[1])
|
|
526
|
+
# SQLite 3.35.0+ supports RETURNING statements.
|
|
527
|
+
# 3.35.0 was released in March 2021.
|
|
528
|
+
if not ((major > 3) or (major == 3 and minor >= 35)):
|
|
529
|
+
raise RuntimeError(
|
|
530
|
+
f'SQLite version {version_str} is not supported. '
|
|
531
|
+
'Please upgrade to SQLite 3.35.0 or later.')
|
|
463
532
|
|
|
464
533
|
|
|
465
534
|
def request_lock_path(request_id: str) -> str:
|
|
@@ -468,6 +537,132 @@ def request_lock_path(request_id: str) -> str:
|
|
|
468
537
|
return os.path.join(lock_path, f'.{request_id}.lock')
|
|
469
538
|
|
|
470
539
|
|
|
540
|
+
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
|
|
541
|
+
"""Kill all pending and running requests for a cluster.
|
|
542
|
+
|
|
543
|
+
Args:
|
|
544
|
+
cluster_name: the name of the cluster.
|
|
545
|
+
exclude_request_names: exclude requests with these names. This is to
|
|
546
|
+
prevent killing the caller request.
|
|
547
|
+
"""
|
|
548
|
+
request_ids = [
|
|
549
|
+
request_task.request_id
|
|
550
|
+
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
|
|
551
|
+
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
|
|
552
|
+
exclude_request_names=[exclude_request_name],
|
|
553
|
+
cluster_names=[cluster_name],
|
|
554
|
+
fields=['request_id']))
|
|
555
|
+
]
|
|
556
|
+
_kill_requests(request_ids)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
|
|
560
|
+
user_id: Optional[str] = None) -> List[str]:
|
|
561
|
+
"""Kill requests with a given request ID prefix."""
|
|
562
|
+
expanded_request_ids: Optional[List[str]] = None
|
|
563
|
+
if request_ids is not None:
|
|
564
|
+
expanded_request_ids = []
|
|
565
|
+
for request_id in request_ids:
|
|
566
|
+
request_tasks = get_requests_with_prefix(request_id,
|
|
567
|
+
fields=['request_id'])
|
|
568
|
+
if request_tasks is None or len(request_tasks) == 0:
|
|
569
|
+
continue
|
|
570
|
+
if len(request_tasks) > 1:
|
|
571
|
+
raise ValueError(f'Multiple requests found for '
|
|
572
|
+
f'request ID prefix: {request_id}')
|
|
573
|
+
expanded_request_ids.append(request_tasks[0].request_id)
|
|
574
|
+
return _kill_requests(request_ids=expanded_request_ids, user_id=user_id)
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
# needed for backward compatibility. Remove by v0.10.7 or v0.11.0
|
|
578
|
+
kill_requests = kill_requests_with_prefix
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def _should_kill_request(request_id: str,
|
|
582
|
+
request_record: Optional[Request]) -> bool:
|
|
583
|
+
if request_record is None:
|
|
584
|
+
logger.debug(f'No request ID {request_id}')
|
|
585
|
+
return False
|
|
586
|
+
# Skip internal requests. The internal requests are scheduled with
|
|
587
|
+
# request_id in range(len(INTERNAL_REQUEST_EVENTS)).
|
|
588
|
+
if request_record.request_id in set(
|
|
589
|
+
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
|
|
590
|
+
return False
|
|
591
|
+
if request_record.status > RequestStatus.RUNNING:
|
|
592
|
+
logger.debug(f'Request {request_id} already finished')
|
|
593
|
+
return False
|
|
594
|
+
return True
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def _kill_requests(request_ids: Optional[List[str]] = None,
|
|
598
|
+
user_id: Optional[str] = None) -> List[str]:
|
|
599
|
+
"""Kill a SkyPilot API request and set its status to cancelled.
|
|
600
|
+
|
|
601
|
+
Args:
|
|
602
|
+
request_ids: The request IDs to kill. If None, all requests for the
|
|
603
|
+
user are killed.
|
|
604
|
+
user_id: The user ID to kill requests for. If None, all users are
|
|
605
|
+
killed.
|
|
606
|
+
|
|
607
|
+
Returns:
|
|
608
|
+
A list of request IDs that were cancelled.
|
|
609
|
+
"""
|
|
610
|
+
if request_ids is None:
|
|
611
|
+
request_ids = [
|
|
612
|
+
request_task.request_id
|
|
613
|
+
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
|
|
614
|
+
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
|
|
615
|
+
# Avoid cancelling the cancel request itself.
|
|
616
|
+
exclude_request_names=['sky.api_cancel'],
|
|
617
|
+
user_id=user_id,
|
|
618
|
+
fields=['request_id']))
|
|
619
|
+
]
|
|
620
|
+
cancelled_request_ids = []
|
|
621
|
+
for request_id in request_ids:
|
|
622
|
+
with update_request(request_id) as request_record:
|
|
623
|
+
if not _should_kill_request(request_id, request_record):
|
|
624
|
+
continue
|
|
625
|
+
if request_record.pid is not None:
|
|
626
|
+
logger.debug(f'Killing request process {request_record.pid}')
|
|
627
|
+
# Use SIGTERM instead of SIGKILL:
|
|
628
|
+
# - The executor can handle SIGTERM gracefully
|
|
629
|
+
# - After SIGTERM, the executor can reuse the request process
|
|
630
|
+
# for other requests, avoiding the overhead of forking a new
|
|
631
|
+
# process for each request.
|
|
632
|
+
os.kill(request_record.pid, signal.SIGTERM)
|
|
633
|
+
request_record.status = RequestStatus.CANCELLED
|
|
634
|
+
request_record.finished_at = time.time()
|
|
635
|
+
cancelled_request_ids.append(request_id)
|
|
636
|
+
return cancelled_request_ids
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
@init_db_async
|
|
640
|
+
@asyncio_utils.shield
|
|
641
|
+
async def kill_request_async(request_id: str) -> bool:
|
|
642
|
+
"""Kill a SkyPilot API request and set its status to cancelled.
|
|
643
|
+
|
|
644
|
+
Returns:
|
|
645
|
+
True if the request was killed, False otherwise.
|
|
646
|
+
"""
|
|
647
|
+
async with filelock.AsyncFileLock(request_lock_path(request_id)):
|
|
648
|
+
request = await _get_request_no_lock_async(request_id)
|
|
649
|
+
if not _should_kill_request(request_id, request):
|
|
650
|
+
return False
|
|
651
|
+
assert request is not None
|
|
652
|
+
if request.pid is not None:
|
|
653
|
+
logger.debug(f'Killing request process {request.pid}')
|
|
654
|
+
# Use SIGTERM instead of SIGKILL:
|
|
655
|
+
# - The executor can handle SIGTERM gracefully
|
|
656
|
+
# - After SIGTERM, the executor can reuse the request process
|
|
657
|
+
# for other requests, avoiding the overhead of forking a new
|
|
658
|
+
# process for each request.
|
|
659
|
+
os.kill(request.pid, signal.SIGTERM)
|
|
660
|
+
request.status = RequestStatus.CANCELLED
|
|
661
|
+
request.finished_at = time.time()
|
|
662
|
+
await _add_or_update_request_no_lock_async(request)
|
|
663
|
+
return True
|
|
664
|
+
|
|
665
|
+
|
|
471
666
|
@contextlib.contextmanager
|
|
472
667
|
@init_db
|
|
473
668
|
@metrics_lib.time_me
|
|
@@ -482,85 +677,144 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
|
|
|
482
677
|
_add_or_update_request_no_lock(request)
|
|
483
678
|
|
|
484
679
|
|
|
485
|
-
@
|
|
680
|
+
@init_db_async
|
|
486
681
|
@metrics_lib.time_me
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
"""
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
@contextlib.asynccontextmanager
|
|
496
|
-
async def _cm():
|
|
497
|
-
# Acquire the lock to avoid race conditions between multiple request
|
|
498
|
-
# operations, e.g. execute and cancel.
|
|
499
|
-
async with filelock.AsyncFileLock(request_lock_path(request_id)):
|
|
500
|
-
request = await _get_request_no_lock_async(request_id)
|
|
501
|
-
try:
|
|
502
|
-
yield request
|
|
503
|
-
finally:
|
|
504
|
-
if request is not None:
|
|
505
|
-
await _add_or_update_request_no_lock_async(request)
|
|
506
|
-
|
|
507
|
-
return _cm()
|
|
682
|
+
@asyncio_utils.shield
|
|
683
|
+
async def update_status_async(request_id: str, status: RequestStatus) -> None:
|
|
684
|
+
"""Update the status of a request"""
|
|
685
|
+
async with filelock.AsyncFileLock(request_lock_path(request_id)):
|
|
686
|
+
request = await _get_request_no_lock_async(request_id)
|
|
687
|
+
if request is not None:
|
|
688
|
+
request.status = status
|
|
689
|
+
await _add_or_update_request_no_lock_async(request)
|
|
508
690
|
|
|
509
691
|
|
|
510
|
-
|
|
511
|
-
|
|
692
|
+
@init_db_async
|
|
693
|
+
@metrics_lib.time_me
|
|
694
|
+
@asyncio_utils.shield
|
|
695
|
+
async def update_status_msg_async(request_id: str, status_msg: str) -> None:
|
|
696
|
+
"""Update the status message of a request"""
|
|
697
|
+
async with filelock.AsyncFileLock(request_lock_path(request_id)):
|
|
698
|
+
request = await _get_request_no_lock_async(request_id)
|
|
699
|
+
if request is not None:
|
|
700
|
+
request.status_msg = status_msg
|
|
701
|
+
await _add_or_update_request_no_lock_async(request)
|
|
512
702
|
|
|
513
703
|
|
|
514
|
-
def _get_request_no_lock(
|
|
704
|
+
def _get_request_no_lock(
|
|
705
|
+
request_id: str,
|
|
706
|
+
fields: Optional[List[str]] = None) -> Optional[Request]:
|
|
515
707
|
"""Get a SkyPilot API request."""
|
|
516
708
|
assert _DB is not None
|
|
709
|
+
columns_str = ', '.join(REQUEST_COLUMNS)
|
|
710
|
+
if fields:
|
|
711
|
+
columns_str = ', '.join(fields)
|
|
517
712
|
with _DB.conn:
|
|
518
713
|
cursor = _DB.conn.cursor()
|
|
519
|
-
cursor.execute(
|
|
714
|
+
cursor.execute((f'SELECT {columns_str} FROM {REQUEST_TABLE} '
|
|
715
|
+
'WHERE request_id LIKE ?'), (request_id + '%',))
|
|
520
716
|
row = cursor.fetchone()
|
|
521
717
|
if row is None:
|
|
522
718
|
return None
|
|
719
|
+
if fields:
|
|
720
|
+
row = _update_request_row_fields(row, fields)
|
|
523
721
|
return Request.from_row(row)
|
|
524
722
|
|
|
525
723
|
|
|
526
|
-
async def _get_request_no_lock_async(
        request_id: str,
        fields: Optional[List[str]] = None) -> Optional[Request]:
    """Async counterpart of _get_request_no_lock.

    Args:
        request_id: Full request ID or a leading part of it; rows are matched
            with SQL ``LIKE`` against ``request_id + '%'``.
        fields: Columns to select. All REQUEST_COLUMNS are selected when this
            is None or empty.

    Returns:
        The first matching request, or None if no row matches.
    """
    assert _DB is not None
    column_clause = ', '.join(fields if fields else REQUEST_COLUMNS)
    query = (f'SELECT {column_clause} FROM {REQUEST_TABLE} '
             'WHERE request_id LIKE ?')
    async with _DB.execute_fetchall_async(query,
                                          (request_id + '%',)) as matches:
        record = matches[0] if matches else None
        if record is None:
            return None
    if fields:
        # A partial SELECT is passed through _update_request_row_fields
        # before being handed to Request.from_row.
        record = _update_request_row_fields(record, fields)
    return Request.from_row(record)
|
|
535
741
|
|
|
536
742
|
|
|
537
|
-
@init_db_async
# Use the async timing decorator, like the other coroutines in this module,
# instead of the sync `time_me`.
@metrics_lib.time_me_async
async def get_latest_request_id_async() -> Optional[str]:
    """Get the latest request ID.

    Returns:
        The ``request_id`` of the row with the greatest ``created_at``, or
        None when the query returns no rows.
    """
    assert _DB is not None
    async with _DB.execute_fetchall_async(
        (f'SELECT request_id FROM {REQUEST_TABLE} '
         'ORDER BY created_at DESC LIMIT 1')) as rows:
        return rows[0][0] if rows else None
|
|
548
752
|
|
|
549
753
|
|
|
550
754
|
@init_db
@metrics_lib.time_me
def get_request(request_id: str,
                fields: Optional[List[str]] = None) -> Optional[Request]:
    """Look up a SkyPilot API request while holding its file lock.

    Args:
        request_id: ID of the request to fetch.
        fields: Optional subset of columns to load; forwarded unchanged to
            _get_request_no_lock.

    Returns:
        The request, or None if it was not found.
    """
    request_lock = filelock.FileLock(request_lock_path(request_id))
    with request_lock:
        found = _get_request_no_lock(request_id, fields)
        return found
|
|
556
761
|
|
|
557
762
|
|
|
558
763
|
@init_db_async
@metrics_lib.time_me_async
@asyncio_utils.shield
async def get_request_async(
        request_id: str,
        fields: Optional[List[str]] = None) -> Optional[Request]:
    """Async counterpart of get_request.

    Args:
        request_id: ID of the request to fetch.
        fields: Optional subset of columns to load; forwarded unchanged to
            _get_request_no_lock_async.

    Returns:
        The request, or None if it was not found.
    """
    # TODO(aylei): figure out how to remove FileLock here to avoid the overhead
    request_lock = filelock.AsyncFileLock(request_lock_path(request_id))
    async with request_lock:
        found = await _get_request_no_lock_async(request_id, fields)
        return found
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
@init_db
@metrics_lib.time_me
def get_requests_with_prefix(
        request_id_prefix: str,
        fields: Optional[List[str]] = None) -> Optional[List[Request]]:
    """Return all requests whose ID starts with ``request_id_prefix``.

    Args:
        request_id_prefix: Leading part of the request IDs to match; rows are
            matched with SQL ``LIKE`` against ``request_id_prefix + '%'``.
        fields: Columns to select. All REQUEST_COLUMNS are selected when this
            is None or empty.

    Returns:
        The matching requests, or None (not an empty list) when nothing
        matches.
    """
    assert _DB is not None
    column_clause = ', '.join(fields if fields else REQUEST_COLUMNS)
    query = (f'SELECT {column_clause} FROM {REQUEST_TABLE} '
             'WHERE request_id LIKE ?')
    with _DB.conn:
        cursor = _DB.conn.cursor()
        cursor.execute(query, (request_id_prefix + '%',))
        matches = cursor.fetchall()
        if not matches:
            return None
        if fields:
            # Partial SELECTs are passed through _update_request_row_fields
            # before being handed to Request.from_row.
            matches = [
                _update_request_row_fields(match, fields) for match in matches
            ]
        return [Request.from_row(match) for match in matches]
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
@init_db_async
@metrics_lib.time_me_async
@asyncio_utils.shield
async def get_requests_async_with_prefix(
        request_id_prefix: str,
        fields: Optional[List[str]] = None) -> Optional[List[Request]]:
    """Async counterpart of get_requests_with_prefix.

    Args:
        request_id_prefix: Leading part of the request IDs to match; rows are
            matched with SQL ``LIKE`` against ``request_id_prefix + '%'``.
        fields: Columns to select. All REQUEST_COLUMNS are selected when this
            is None or empty.

    Returns:
        The matching requests, or None (not an empty list) when nothing
        matches.
    """
    assert _DB is not None
    column_clause = ', '.join(fields if fields else REQUEST_COLUMNS)
    query = (f'SELECT {column_clause} FROM {REQUEST_TABLE} '
             'WHERE request_id LIKE ?')
    async with _DB.execute_fetchall_async(
            query, (request_id_prefix + '%',)) as matches:
        if not matches:
            return None
        if fields:
            # Partial SELECTs are passed through _update_request_row_fields
            # before being handed to Request.from_row.
            matches = [
                _update_request_row_fields(match, fields) for match in matches
            ]
        return [Request.from_row(match) for match in matches]
|
|
564
818
|
|
|
565
819
|
|
|
566
820
|
class StatusWithMsg(NamedTuple):
|
|
@@ -597,26 +851,29 @@ async def get_request_status_async(
|
|
|
597
851
|
return StatusWithMsg(status, status_msg)
|
|
598
852
|
|
|
599
853
|
|
|
600
|
-
@init_db
|
|
601
|
-
@metrics_lib.time_me
|
|
602
|
-
def create_if_not_exists(request: Request) -> bool:
|
|
603
|
-
"""Create a SkyPilot API request if it does not exist."""
|
|
604
|
-
with filelock.FileLock(request_lock_path(request.request_id)):
|
|
605
|
-
if _get_request_no_lock(request.request_id) is not None:
|
|
606
|
-
return False
|
|
607
|
-
_add_or_update_request_no_lock(request)
|
|
608
|
-
return True
|
|
609
|
-
|
|
610
|
-
|
|
611
854
|
@init_db_async
@metrics_lib.time_me_async
@asyncio_utils.shield
async def create_if_not_exists_async(request: Request) -> bool:
    """Insert ``request`` unless a row with its request_id already exists.

    The insert uses ``ON CONFLICT(request_id) DO NOTHING`` with a
    ``RETURNING`` clause, so a value comes back only when a row was inserted.

    Returns:
        True if a new request is created, False if the request already exists.
    """
    assert _DB is not None
    column_list = ', '.join(REQUEST_COLUMNS)
    placeholders = ', '.join('?' for _ in REQUEST_COLUMNS)
    insert_sql = (f'INSERT INTO {REQUEST_TABLE} ({column_list}) '
                  f'VALUES ({placeholders}) '
                  'ON CONFLICT(request_id) DO NOTHING RETURNING ROWID')
    values = request.to_row()
    # Execute the SQL statement without getting the request lock.
    # The request lock is used to prevent racing with cancellation codepath,
    # but a request cannot be cancelled before it is created.
    returned = await _DB.execute_get_returning_value_async(insert_sql, values)
    return bool(returned)
|
|
620
877
|
|
|
621
878
|
|
|
622
879
|
@dataclasses.dataclass
|
|
@@ -634,6 +891,7 @@ class RequestTaskFilter:
|
|
|
634
891
|
Mutually exclusive with exclude_request_names.
|
|
635
892
|
finished_before: if provided, only include requests finished before this
|
|
636
893
|
timestamp.
|
|
894
|
+
limit: the number of requests to show. If None, show all requests.
|
|
637
895
|
|
|
638
896
|
Raises:
|
|
639
897
|
ValueError: If both exclude_request_names and include_request_names are
|
|
@@ -645,6 +903,9 @@ class RequestTaskFilter:
|
|
|
645
903
|
exclude_request_names: Optional[List[str]] = None
|
|
646
904
|
include_request_names: Optional[List[str]] = None
|
|
647
905
|
finished_before: Optional[float] = None
|
|
906
|
+
limit: Optional[int] = None
|
|
907
|
+
fields: Optional[List[str]] = None
|
|
908
|
+
sort: bool = False
|
|
648
909
|
|
|
649
910
|
def __post_init__(self):
|
|
650
911
|
if (self.exclude_request_names is not None and
|
|
@@ -665,6 +926,10 @@ class RequestTaskFilter:
|
|
|
665
926
|
status_list_str = ','.join(
|
|
666
927
|
repr(status.value) for status in self.status)
|
|
667
928
|
filters.append(f'status IN ({status_list_str})')
|
|
929
|
+
if self.include_request_names is not None:
|
|
930
|
+
request_names_str = ','.join(
|
|
931
|
+
repr(name) for name in self.include_request_names)
|
|
932
|
+
filters.append(f'name IN ({request_names_str})')
|
|
668
933
|
if self.exclude_request_names is not None:
|
|
669
934
|
exclude_request_names_str = ','.join(
|
|
670
935
|
repr(name) for name in self.exclude_request_names)
|
|
@@ -676,10 +941,6 @@ class RequestTaskFilter:
|
|
|
676
941
|
if self.user_id is not None:
|
|
677
942
|
filters.append(f'{COL_USER_ID} = ?')
|
|
678
943
|
filter_params.append(self.user_id)
|
|
679
|
-
if self.include_request_names is not None:
|
|
680
|
-
request_names_str = ','.join(
|
|
681
|
-
repr(name) for name in self.include_request_names)
|
|
682
|
-
filters.append(f'name IN ({request_names_str})')
|
|
683
944
|
if self.finished_before is not None:
|
|
684
945
|
filters.append('finished_at < ?')
|
|
685
946
|
filter_params.append(self.finished_before)
|
|
@@ -687,8 +948,16 @@ class RequestTaskFilter:
|
|
|
687
948
|
if filter_str:
|
|
688
949
|
filter_str = f' WHERE {filter_str}'
|
|
689
950
|
columns_str = ', '.join(REQUEST_COLUMNS)
|
|
690
|
-
|
|
691
|
-
|
|
951
|
+
if self.fields:
|
|
952
|
+
columns_str = ', '.join(self.fields)
|
|
953
|
+
sort_str = ''
|
|
954
|
+
if self.sort:
|
|
955
|
+
sort_str = ' ORDER BY created_at DESC'
|
|
956
|
+
query_str = (f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str}'
|
|
957
|
+
f'{sort_str}')
|
|
958
|
+
if self.limit is not None:
|
|
959
|
+
query_str += f' LIMIT {self.limit}'
|
|
960
|
+
return query_str, filter_params
|
|
692
961
|
|
|
693
962
|
|
|
694
963
|
@init_db
|
|
@@ -707,6 +976,10 @@ def get_request_tasks(req_filter: RequestTaskFilter) -> List[Request]:
|
|
|
707
976
|
rows = cursor.fetchall()
|
|
708
977
|
if rows is None:
|
|
709
978
|
return []
|
|
979
|
+
if req_filter.fields:
|
|
980
|
+
rows = [
|
|
981
|
+
_update_request_row_fields(row, req_filter.fields) for row in rows
|
|
982
|
+
]
|
|
710
983
|
return [Request.from_row(row) for row in rows]
|
|
711
984
|
|
|
712
985
|
|
|
@@ -719,6 +992,10 @@ async def get_request_tasks_async(
|
|
|
719
992
|
async with _DB.execute_fetchall_async(*req_filter.build_query()) as rows:
|
|
720
993
|
if not rows:
|
|
721
994
|
return []
|
|
995
|
+
if req_filter.fields:
|
|
996
|
+
rows = [
|
|
997
|
+
_update_request_row_fields(row, req_filter.fields) for row in rows
|
|
998
|
+
]
|
|
722
999
|
return [Request.from_row(row) for row in rows]
|
|
723
1000
|
|
|
724
1001
|
|
|
@@ -776,6 +1053,23 @@ def set_request_failed(request_id: str, e: BaseException) -> None:
|
|
|
776
1053
|
request_task.set_error(e)
|
|
777
1054
|
|
|
778
1055
|
|
|
1056
|
+
@init_db_async
@metrics_lib.time_me_async
@asyncio_utils.shield
async def set_request_failed_async(request_id: str, e: BaseException) -> None:
    """Mark a request as FAILED and record the error on it.

    The formatted traceback of the exception currently being handled is
    attached to ``e`` as a ``stacktrace`` attribute before the error is
    stored.

    Args:
        request_id: ID of the request to update; it must exist.
        e: The exception that caused the failure.
    """
    with ux_utils.enable_traceback():
        formatted_trace = traceback.format_exc()
    setattr(e, 'stacktrace', formatted_trace)
    lock_path = request_lock_path(request_id)
    async with filelock.AsyncFileLock(lock_path):
        failed_request = await _get_request_no_lock_async(request_id)
        assert failed_request is not None, request_id
        failed_request.status = RequestStatus.FAILED
        failed_request.finished_at = time.time()
        failed_request.set_error(e)
        await _add_or_update_request_no_lock_async(failed_request)
|
|
1071
|
+
|
|
1072
|
+
|
|
779
1073
|
def set_request_succeeded(request_id: str, result: Optional[Any]) -> None:
|
|
780
1074
|
"""Set a request to succeeded and populate the result."""
|
|
781
1075
|
with update_request(request_id) as request_task:
|
|
@@ -786,28 +1080,50 @@ def set_request_succeeded(request_id: str, result: Optional[Any]) -> None:
|
|
|
786
1080
|
request_task.set_return_value(result)
|
|
787
1081
|
|
|
788
1082
|
|
|
789
|
-
|
|
1083
|
+
@init_db_async
@metrics_lib.time_me_async
@asyncio_utils.shield
async def set_request_succeeded_async(request_id: str,
                                      result: Optional[Any]) -> None:
    """Mark a request as SUCCEEDED and store its return value, if any.

    Args:
        request_id: ID of the request to update; it must exist.
        result: Value recorded through ``set_return_value``. Nothing is
            recorded when it is None.
    """
    lock_path = request_lock_path(request_id)
    async with filelock.AsyncFileLock(lock_path):
        finished_request = await _get_request_no_lock_async(request_id)
        assert finished_request is not None, request_id
        finished_request.status = RequestStatus.SUCCEEDED
        finished_request.finished_at = time.time()
        if result is not None:
            finished_request.set_return_value(result)
        await _add_or_update_request_no_lock_async(finished_request)
|
|
1097
|
+
|
|
1098
|
+
|
|
1099
|
+
@init_db_async
@metrics_lib.time_me_async
@asyncio_utils.shield
async def set_request_cancelled_async(request_id: str) -> None:
    """Mark a pending or running request as CANCELLED.

    A request whose status already compares greater than RUNNING is left
    untouched.

    Args:
        request_id: ID of the request to cancel; it must exist.
    """
    lock_path = request_lock_path(request_id)
    async with filelock.AsyncFileLock(lock_path):
        target = await _get_request_no_lock_async(request_id)
        assert target is not None, request_id
        if target.status > RequestStatus.RUNNING:
            # Already finished or cancelled.
            return
        target.finished_at = time.time()
        target.status = RequestStatus.CANCELLED
        await _add_or_update_request_no_lock_async(target)
|
|
798
1113
|
|
|
799
1114
|
|
|
800
1115
|
# This is a coroutine, so use the async DB-init and timing decorators like the
# other async functions in this module rather than the sync ones.
@init_db_async
@metrics_lib.time_me_async
async def _delete_requests(request_ids: List[str]) -> None:
    """Clean up requests by their IDs.

    Args:
        request_ids: IDs of the rows to delete from the request table. An
            empty list is a no-op.
    """
    if not request_ids:
        # Nothing to delete; skip the database round trip.
        return
    # NOTE(review): the IDs are embedded in the statement via repr() rather
    # than bound as parameters, so callers must only pass trusted IDs.
    id_list_str = ','.join(repr(request_id) for request_id in request_ids)
    assert _DB is not None
    await _DB.execute_and_commit_async(
        f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')
|
|
808
1123
|
|
|
809
1124
|
|
|
810
|
-
async def clean_finished_requests_with_retention(retention_seconds: int
|
|
1125
|
+
async def clean_finished_requests_with_retention(retention_seconds: int,
|
|
1126
|
+
batch_size: int = 1000):
|
|
811
1127
|
"""Clean up finished requests older than the retention period.
|
|
812
1128
|
|
|
813
1129
|
This function removes old finished requests (SUCCEEDED, FAILED, CANCELLED)
|
|
@@ -816,24 +1132,40 @@ async def clean_finished_requests_with_retention(retention_seconds: int):
|
|
|
816
1132
|
Args:
|
|
817
1133
|
retention_seconds: Requests older than this many seconds will be
|
|
818
1134
|
deleted.
|
|
1135
|
+
batch_size: batch delete 'batch_size' requests at a time to
|
|
1136
|
+
avoid using too much memory at once and to let each
|
|
1137
|
+
db query complete in a reasonable time. All stale
|
|
1138
|
+
requests older than the retention period will be deleted
|
|
1139
|
+
regardless of the batch size.
|
|
819
1140
|
"""
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
1141
|
+
total_deleted = 0
|
|
1142
|
+
while True:
|
|
1143
|
+
reqs = await get_request_tasks_async(
|
|
1144
|
+
req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
|
|
1145
|
+
finished_before=time.time() -
|
|
1146
|
+
retention_seconds,
|
|
1147
|
+
limit=batch_size,
|
|
1148
|
+
fields=['request_id']))
|
|
1149
|
+
if len(reqs) == 0:
|
|
1150
|
+
break
|
|
1151
|
+
futs = []
|
|
1152
|
+
for req in reqs:
|
|
1153
|
+
# req.log_path is derived from request_id,
|
|
1154
|
+
# so it's ok to just grab the request_id in the above query.
|
|
1155
|
+
futs.append(
|
|
1156
|
+
asyncio.create_task(
|
|
1157
|
+
anyio.Path(
|
|
1158
|
+
req.log_path.absolute()).unlink(missing_ok=True)))
|
|
1159
|
+
await asyncio.gather(*futs)
|
|
1160
|
+
|
|
1161
|
+
await _delete_requests([req.request_id for req in reqs])
|
|
1162
|
+
total_deleted += len(reqs)
|
|
1163
|
+
if len(reqs) < batch_size:
|
|
1164
|
+
break
|
|
833
1165
|
|
|
834
1166
|
# To avoid leakage of the log file, logs must be deleted before the
|
|
835
1167
|
# request task in the database.
|
|
836
|
-
logger.info(f'Cleaned up {
|
|
1168
|
+
logger.info(f'Cleaned up {total_deleted} finished requests '
|
|
837
1169
|
f'older than {retention_seconds} seconds')
|
|
838
1170
|
|
|
839
1171
|
|