skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +64 -32
- sky/adaptors/aws.py +23 -6
- sky/adaptors/azure.py +432 -15
- sky/adaptors/cloudflare.py +5 -5
- sky/adaptors/common.py +19 -9
- sky/adaptors/do.py +20 -0
- sky/adaptors/gcp.py +3 -2
- sky/adaptors/kubernetes.py +122 -88
- sky/adaptors/nebius.py +100 -0
- sky/adaptors/oci.py +39 -1
- sky/adaptors/vast.py +29 -0
- sky/admin_policy.py +101 -0
- sky/authentication.py +117 -98
- sky/backends/backend.py +52 -20
- sky/backends/backend_utils.py +669 -557
- sky/backends/cloud_vm_ray_backend.py +1099 -808
- sky/backends/local_docker_backend.py +14 -8
- sky/backends/wheel_utils.py +38 -20
- sky/benchmark/benchmark_utils.py +22 -23
- sky/check.py +76 -27
- sky/cli.py +1586 -1139
- sky/client/__init__.py +1 -0
- sky/client/cli.py +5683 -0
- sky/client/common.py +345 -0
- sky/client/sdk.py +1765 -0
- sky/cloud_stores.py +283 -19
- sky/clouds/__init__.py +7 -2
- sky/clouds/aws.py +303 -112
- sky/clouds/azure.py +185 -179
- sky/clouds/cloud.py +115 -37
- sky/clouds/cudo.py +29 -22
- sky/clouds/do.py +313 -0
- sky/clouds/fluidstack.py +44 -54
- sky/clouds/gcp.py +206 -65
- sky/clouds/ibm.py +26 -21
- sky/clouds/kubernetes.py +345 -91
- sky/clouds/lambda_cloud.py +40 -29
- sky/clouds/nebius.py +297 -0
- sky/clouds/oci.py +129 -90
- sky/clouds/paperspace.py +22 -18
- sky/clouds/runpod.py +53 -34
- sky/clouds/scp.py +28 -24
- sky/clouds/service_catalog/__init__.py +19 -13
- sky/clouds/service_catalog/aws_catalog.py +29 -12
- sky/clouds/service_catalog/azure_catalog.py +33 -6
- sky/clouds/service_catalog/common.py +95 -75
- sky/clouds/service_catalog/constants.py +3 -3
- sky/clouds/service_catalog/cudo_catalog.py +13 -3
- sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
- sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
- sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
- sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
- sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
- sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
- sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
- sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
- sky/clouds/service_catalog/do_catalog.py +111 -0
- sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
- sky/clouds/service_catalog/gcp_catalog.py +16 -2
- sky/clouds/service_catalog/ibm_catalog.py +2 -2
- sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
- sky/clouds/service_catalog/lambda_catalog.py +8 -3
- sky/clouds/service_catalog/nebius_catalog.py +116 -0
- sky/clouds/service_catalog/oci_catalog.py +31 -4
- sky/clouds/service_catalog/paperspace_catalog.py +2 -2
- sky/clouds/service_catalog/runpod_catalog.py +2 -2
- sky/clouds/service_catalog/scp_catalog.py +2 -2
- sky/clouds/service_catalog/vast_catalog.py +104 -0
- sky/clouds/service_catalog/vsphere_catalog.py +2 -2
- sky/clouds/utils/aws_utils.py +65 -0
- sky/clouds/utils/azure_utils.py +91 -0
- sky/clouds/utils/gcp_utils.py +5 -9
- sky/clouds/utils/oci_utils.py +47 -5
- sky/clouds/utils/scp_utils.py +4 -3
- sky/clouds/vast.py +280 -0
- sky/clouds/vsphere.py +22 -18
- sky/core.py +361 -107
- sky/dag.py +41 -28
- sky/data/data_transfer.py +37 -0
- sky/data/data_utils.py +211 -32
- sky/data/mounting_utils.py +182 -30
- sky/data/storage.py +2118 -270
- sky/data/storage_utils.py +126 -5
- sky/exceptions.py +179 -8
- sky/execution.py +158 -85
- sky/global_user_state.py +150 -34
- sky/jobs/__init__.py +12 -10
- sky/jobs/client/__init__.py +0 -0
- sky/jobs/client/sdk.py +302 -0
- sky/jobs/constants.py +49 -11
- sky/jobs/controller.py +161 -99
- sky/jobs/dashboard/dashboard.py +171 -25
- sky/jobs/dashboard/templates/index.html +572 -60
- sky/jobs/recovery_strategy.py +157 -156
- sky/jobs/scheduler.py +307 -0
- sky/jobs/server/__init__.py +1 -0
- sky/jobs/server/core.py +598 -0
- sky/jobs/server/dashboard_utils.py +69 -0
- sky/jobs/server/server.py +190 -0
- sky/jobs/state.py +627 -122
- sky/jobs/utils.py +615 -206
- sky/models.py +27 -0
- sky/optimizer.py +142 -83
- sky/provision/__init__.py +20 -5
- sky/provision/aws/config.py +124 -42
- sky/provision/aws/instance.py +130 -53
- sky/provision/azure/__init__.py +7 -0
- sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
- sky/provision/azure/config.py +220 -0
- sky/provision/azure/instance.py +1012 -37
- sky/provision/common.py +31 -3
- sky/provision/constants.py +25 -0
- sky/provision/cudo/__init__.py +2 -1
- sky/provision/cudo/cudo_utils.py +112 -0
- sky/provision/cudo/cudo_wrapper.py +37 -16
- sky/provision/cudo/instance.py +28 -12
- sky/provision/do/__init__.py +11 -0
- sky/provision/do/config.py +14 -0
- sky/provision/do/constants.py +10 -0
- sky/provision/do/instance.py +287 -0
- sky/provision/do/utils.py +301 -0
- sky/provision/docker_utils.py +82 -46
- sky/provision/fluidstack/fluidstack_utils.py +57 -125
- sky/provision/fluidstack/instance.py +15 -43
- sky/provision/gcp/config.py +19 -9
- sky/provision/gcp/constants.py +7 -1
- sky/provision/gcp/instance.py +55 -34
- sky/provision/gcp/instance_utils.py +339 -80
- sky/provision/gcp/mig_utils.py +210 -0
- sky/provision/instance_setup.py +172 -133
- sky/provision/kubernetes/__init__.py +1 -0
- sky/provision/kubernetes/config.py +104 -90
- sky/provision/kubernetes/constants.py +8 -0
- sky/provision/kubernetes/instance.py +680 -325
- sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
- sky/provision/kubernetes/network.py +54 -20
- sky/provision/kubernetes/network_utils.py +70 -21
- sky/provision/kubernetes/utils.py +1370 -251
- sky/provision/lambda_cloud/__init__.py +11 -0
- sky/provision/lambda_cloud/config.py +10 -0
- sky/provision/lambda_cloud/instance.py +265 -0
- sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
- sky/provision/logging.py +1 -1
- sky/provision/nebius/__init__.py +11 -0
- sky/provision/nebius/config.py +11 -0
- sky/provision/nebius/instance.py +285 -0
- sky/provision/nebius/utils.py +318 -0
- sky/provision/oci/__init__.py +15 -0
- sky/provision/oci/config.py +51 -0
- sky/provision/oci/instance.py +436 -0
- sky/provision/oci/query_utils.py +681 -0
- sky/provision/paperspace/constants.py +6 -0
- sky/provision/paperspace/instance.py +4 -3
- sky/provision/paperspace/utils.py +2 -0
- sky/provision/provisioner.py +207 -130
- sky/provision/runpod/__init__.py +1 -0
- sky/provision/runpod/api/__init__.py +3 -0
- sky/provision/runpod/api/commands.py +119 -0
- sky/provision/runpod/api/pods.py +142 -0
- sky/provision/runpod/instance.py +64 -8
- sky/provision/runpod/utils.py +239 -23
- sky/provision/vast/__init__.py +10 -0
- sky/provision/vast/config.py +11 -0
- sky/provision/vast/instance.py +247 -0
- sky/provision/vast/utils.py +162 -0
- sky/provision/vsphere/common/vim_utils.py +1 -1
- sky/provision/vsphere/instance.py +8 -18
- sky/provision/vsphere/vsphere_utils.py +1 -1
- sky/resources.py +247 -102
- sky/serve/__init__.py +9 -9
- sky/serve/autoscalers.py +361 -299
- sky/serve/client/__init__.py +0 -0
- sky/serve/client/sdk.py +366 -0
- sky/serve/constants.py +12 -3
- sky/serve/controller.py +106 -36
- sky/serve/load_balancer.py +63 -12
- sky/serve/load_balancing_policies.py +84 -2
- sky/serve/replica_managers.py +42 -34
- sky/serve/serve_state.py +62 -32
- sky/serve/serve_utils.py +271 -160
- sky/serve/server/__init__.py +0 -0
- sky/serve/{core.py → server/core.py} +271 -90
- sky/serve/server/server.py +112 -0
- sky/serve/service.py +52 -16
- sky/serve/service_spec.py +95 -32
- sky/server/__init__.py +1 -0
- sky/server/common.py +430 -0
- sky/server/constants.py +21 -0
- sky/server/html/log.html +174 -0
- sky/server/requests/__init__.py +0 -0
- sky/server/requests/executor.py +472 -0
- sky/server/requests/payloads.py +487 -0
- sky/server/requests/queues/__init__.py +0 -0
- sky/server/requests/queues/mp_queue.py +76 -0
- sky/server/requests/requests.py +567 -0
- sky/server/requests/serializers/__init__.py +0 -0
- sky/server/requests/serializers/decoders.py +192 -0
- sky/server/requests/serializers/encoders.py +166 -0
- sky/server/server.py +1106 -0
- sky/server/stream_utils.py +141 -0
- sky/setup_files/MANIFEST.in +2 -5
- sky/setup_files/dependencies.py +159 -0
- sky/setup_files/setup.py +14 -125
- sky/sky_logging.py +59 -14
- sky/skylet/autostop_lib.py +2 -2
- sky/skylet/constants.py +183 -50
- sky/skylet/events.py +22 -10
- sky/skylet/job_lib.py +403 -258
- sky/skylet/log_lib.py +111 -71
- sky/skylet/log_lib.pyi +6 -0
- sky/skylet/providers/command_runner.py +6 -8
- sky/skylet/providers/ibm/node_provider.py +2 -2
- sky/skylet/providers/scp/config.py +11 -3
- sky/skylet/providers/scp/node_provider.py +8 -8
- sky/skylet/skylet.py +3 -1
- sky/skylet/subprocess_daemon.py +69 -17
- sky/skypilot_config.py +119 -57
- sky/task.py +205 -64
- sky/templates/aws-ray.yml.j2 +37 -7
- sky/templates/azure-ray.yml.j2 +27 -82
- sky/templates/cudo-ray.yml.j2 +7 -3
- sky/templates/do-ray.yml.j2 +98 -0
- sky/templates/fluidstack-ray.yml.j2 +7 -4
- sky/templates/gcp-ray.yml.j2 +26 -6
- sky/templates/ibm-ray.yml.j2 +3 -2
- sky/templates/jobs-controller.yaml.j2 +46 -11
- sky/templates/kubernetes-ingress.yml.j2 +7 -0
- sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
- sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
- sky/templates/kubernetes-ray.yml.j2 +292 -25
- sky/templates/lambda-ray.yml.j2 +30 -40
- sky/templates/nebius-ray.yml.j2 +79 -0
- sky/templates/oci-ray.yml.j2 +18 -57
- sky/templates/paperspace-ray.yml.j2 +10 -6
- sky/templates/runpod-ray.yml.j2 +26 -4
- sky/templates/scp-ray.yml.j2 +3 -2
- sky/templates/sky-serve-controller.yaml.j2 +12 -1
- sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
- sky/templates/vast-ray.yml.j2 +70 -0
- sky/templates/vsphere-ray.yml.j2 +8 -3
- sky/templates/websocket_proxy.py +64 -0
- sky/usage/constants.py +10 -1
- sky/usage/usage_lib.py +130 -37
- sky/utils/accelerator_registry.py +35 -51
- sky/utils/admin_policy_utils.py +147 -0
- sky/utils/annotations.py +51 -0
- sky/utils/cli_utils/status_utils.py +81 -23
- sky/utils/cluster_utils.py +356 -0
- sky/utils/command_runner.py +452 -89
- sky/utils/command_runner.pyi +77 -3
- sky/utils/common.py +54 -0
- sky/utils/common_utils.py +319 -108
- sky/utils/config_utils.py +204 -0
- sky/utils/control_master_utils.py +48 -0
- sky/utils/controller_utils.py +548 -266
- sky/utils/dag_utils.py +93 -32
- sky/utils/db_utils.py +18 -4
- sky/utils/env_options.py +29 -7
- sky/utils/kubernetes/create_cluster.sh +8 -60
- sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
- sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
- sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
- sky/utils/kubernetes/gpu_labeler.py +4 -4
- sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
- sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
- sky/utils/kubernetes/rsync_helper.sh +24 -0
- sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
- sky/utils/log_utils.py +240 -33
- sky/utils/message_utils.py +81 -0
- sky/utils/registry.py +127 -0
- sky/utils/resources_utils.py +94 -22
- sky/utils/rich_utils.py +247 -18
- sky/utils/schemas.py +284 -64
- sky/{status_lib.py → utils/status_lib.py} +12 -7
- sky/utils/subprocess_utils.py +212 -46
- sky/utils/timeline.py +12 -7
- sky/utils/ux_utils.py +168 -15
- skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
- skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
- sky/clouds/cloud_registry.py +0 -31
- sky/jobs/core.py +0 -330
- sky/skylet/providers/azure/__init__.py +0 -2
- sky/skylet/providers/azure/azure-vm-template.json +0 -301
- sky/skylet/providers/azure/config.py +0 -170
- sky/skylet/providers/azure/node_provider.py +0 -466
- sky/skylet/providers/lambda_cloud/__init__.py +0 -2
- sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
- sky/skylet/providers/oci/__init__.py +0 -2
- sky/skylet/providers/oci/node_provider.py +0 -488
- sky/skylet/providers/oci/query_helper.py +0 -383
- sky/skylet/providers/oci/utils.py +0 -21
- sky/utils/cluster_yaml_utils.py +0 -24
- sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
- skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
- skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,567 @@
|
|
1
|
+
"""Utilities for REST API."""
|
2
|
+
import contextlib
|
3
|
+
import dataclasses
|
4
|
+
import enum
|
5
|
+
import functools
|
6
|
+
import json
|
7
|
+
import os
|
8
|
+
import pathlib
|
9
|
+
import shutil
|
10
|
+
import signal
|
11
|
+
import sqlite3
|
12
|
+
import time
|
13
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
14
|
+
|
15
|
+
import colorama
|
16
|
+
import filelock
|
17
|
+
|
18
|
+
from sky import exceptions
|
19
|
+
from sky import global_user_state
|
20
|
+
from sky import sky_logging
|
21
|
+
from sky.server import common as server_common
|
22
|
+
from sky.server import constants as server_constants
|
23
|
+
from sky.server.requests import payloads
|
24
|
+
from sky.server.requests.serializers import decoders
|
25
|
+
from sky.server.requests.serializers import encoders
|
26
|
+
from sky.utils import common
|
27
|
+
from sky.utils import common_utils
|
28
|
+
from sky.utils import db_utils
|
29
|
+
from sky.utils import env_options
|
30
|
+
|
31
|
+
logger = sky_logging.init_logger(__name__)

# Name of the SQLite table that stores API requests. The database file lives
# at server_constants.API_SERVER_REQUEST_DB_PATH (see _DB_PATH).
REQUEST_TABLE = 'requests'
# Column names that queries in this module refer to by constant.
COL_CLUSTER_NAME = 'cluster_name'
COL_USER_ID = 'user_id'
# Directory (before `~` expansion) holding per-request log files and the
# hidden per-request lock files.
REQUEST_LOG_PATH_PREFIX = '~/sky_logs/api_server/requests'
|
38
|
+
|
39
|
+
# TODO(zhwu): For scalability, there are several TODOs:
|
40
|
+
# [x] Have a way to queue requests.
|
41
|
+
# [ ] Move logs to persistent place.
|
42
|
+
# [ ] Deploy the API server in an autoscaling fashion.
|
43
|
+
|
44
|
+
|
45
|
+
class RequestStatus(enum.Enum):
    """The status of a request.

    Members are declared in lifecycle order, and that declaration order is
    also the ordering used by the comparison operators:
    PENDING < RUNNING < SUCCEEDED < FAILED < CANCELLED.
    For example, ``status > RequestStatus.RUNNING`` tests whether a request
    has finished.
    """

    PENDING = 'PENDING'
    RUNNING = 'RUNNING'
    SUCCEEDED = 'SUCCEEDED'
    FAILED = 'FAILED'
    CANCELLED = 'CANCELLED'

    def _rank(self) -> int:
        """Position of this member in declaration order."""
        return list(type(self)).index(self)

    # All four ordering operators are defined so that <=, >= work as well as
    # > and <. A non-RequestStatus operand yields NotImplemented, so Python
    # raises TypeError instead of leaking a ValueError from list.index.
    def __gt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self._rank() > other._rank()

    def __ge__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self._rank() >= other._rank()

    def __lt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self._rank() < other._rank()

    def __le__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self._rank() <= other._rank()

    def colored_str(self):
        """Return the status value wrapped in its terminal color codes."""
        color = _STATUS_TO_COLOR[self]
        return f'{color}{self.value}{colorama.Style.RESET_ALL}'
|
61
|
+
|
62
|
+
|
63
|
+
# Terminal color used when printing each request status
# (see RequestStatus.colored_str).
_STATUS_TO_COLOR = {
    RequestStatus.PENDING: colorama.Fore.BLUE,
    RequestStatus.RUNNING: colorama.Fore.GREEN,
    RequestStatus.SUCCEEDED: colorama.Fore.GREEN,
    RequestStatus.FAILED: colorama.Fore.RED,
    RequestStatus.CANCELLED: colorama.Fore.WHITE,
}

# Columns of the requests table, in the order used for row tuples: the SELECT
# and INSERT statements in this module list columns in this order, and
# Request.to_row / Request.from_row map tuple positions to RequestPayload
# fields by it. RequestPayload.user_name is intentionally not listed, so it is
# not persisted.
REQUEST_COLUMNS = [
    'request_id',
    'name',
    'entrypoint',
    'request_body',
    'status',
    'return_value',
    'error',
    'pid',
    'created_at',
    COL_CLUSTER_NAME,
    'schedule_type',
    COL_USER_ID,
]
|
85
|
+
|
86
|
+
|
87
|
+
class ScheduleType(enum.Enum):
    """Which request queue a request is scheduled on.

    The member values are the strings persisted in the ``schedule_type``
    column of the requests table.
    """

    # Default queue; Request.schedule_type is LONG unless set otherwise.
    LONG = 'long'
    # Queue for requests that should be executed quickly for a quick response.
    SHORT = 'short'
|
92
|
+
|
93
|
+
|
94
|
+
@dataclasses.dataclass
class RequestPayload:
    """The payload for the requests.

    This is the string-encoded form of a :class:`Request`, produced by
    ``Request.encode`` (for storage) or ``Request.readable_encode`` (for
    display), and consumed by ``Request.decode``.
    """

    request_id: str
    name: str
    # Encoded with encoders.pickle_and_encode by Request.encode; holds the
    # entrypoint's __name__ when produced by Request.readable_encode.
    entrypoint: str
    # Encoded with encoders.pickle_and_encode by Request.encode; a JSON dump
    # of the request body when produced by Request.readable_encode.
    request_body: str
    # A RequestStatus value string.
    status: str
    created_at: float
    user_id: str
    # JSON-encoded return value.
    return_value: str
    # JSON-encoded error dict (or JSON null).
    error: str
    pid: Optional[int]
    # A ScheduleType value string.
    schedule_type: str
    # Only filled in by Request.readable_encode; not a column in
    # REQUEST_COLUMNS, so it is never stored in the database.
    user_name: Optional[str] = None
    # Resources the request operates on.
    cluster_name: Optional[str] = None
|
110
|
+
# Resources the request operates on.
|
111
|
+
cluster_name: Optional[str] = None
|
112
|
+
|
113
|
+
|
114
|
+
@dataclasses.dataclass
class Request:
    """A SkyPilot API request.

    This is the in-memory form of a request. ``encode``/``decode`` convert
    it to and from a string-encoded :class:`RequestPayload`, and
    ``to_row``/``from_row`` convert it to and from a row tuple of the
    requests table with values ordered as in ``REQUEST_COLUMNS``.
    """

    request_id: str
    name: str
    entrypoint: Callable
    request_body: payloads.RequestBody
    status: RequestStatus
    created_at: float
    user_id: str
    return_value: Any = None
    error: Optional[Dict[str, Any]] = None
    # The pid of the request worker that is(was) running this request.
    pid: Optional[int] = None
    schedule_type: ScheduleType = ScheduleType.LONG
    # Resources the request operates on.
    cluster_name: Optional[str] = None

    @property
    def log_path(self) -> pathlib.Path:
        """Path of this request's log file, ``<log dir>/<request_id>.log``.

        The log directory (``REQUEST_LOG_PATH_PREFIX`` with ``~`` expanded)
        is created if it does not exist yet.
        """
        log_path_prefix = pathlib.Path(
            REQUEST_LOG_PATH_PREFIX).expanduser().absolute()
        log_path_prefix.mkdir(parents=True, exist_ok=True)
        # Append the suffix instead of using `Path.with_suffix`, which would
        # replace everything after a '.' in the request ID (and raises
        # ValueError for an empty ID).
        return log_path_prefix / f'{self.request_id}.log'

    def set_error(self, error: Exception) -> None:
        """Record `error` on this request in a JSON-serializable form.

        The stored dict holds the encoded serialized exception ('object'),
        the exception class name ('type') and its message ('message').
        """
        # TODO(zhwu): pickle.dump does not work well with custom exceptions
        # that take more than one argument, so the exception is converted by
        # `exceptions.serialize_exception` before being pickled.
        serialized = exceptions.serialize_exception(error)
        self.error = {
            'object': encoders.pickle_and_encode(serialized),
            'type': type(error).__name__,
            'message': str(error),
        }

    def get_error(self) -> Optional[Dict[str, Any]]:
        """Return the recorded error with its exception object restored.

        Returns None if no error has been recorded.
        """
        if self.error is None:
            return None
        unpickled = decoders.decode_and_unpickle(self.error['object'])
        deserialized = exceptions.deserialize_exception(unpickled)
        return {
            'object': deserialized,
            'type': self.error['type'],
            'message': self.error['message'],
        }

    def set_return_value(self, return_value: Any) -> None:
        """Store `return_value` using the encoder for this request's name."""
        self.return_value = encoders.get_encoder(self.name)(return_value)

    def get_return_value(self) -> Any:
        """Return the stored value decoded by this request's name's decoder."""
        return decoders.get_decoder(self.name)(self.return_value)

    @classmethod
    def from_row(cls, row: Tuple[Any, ...]) -> 'Request':
        """Build a request from a table row ordered as REQUEST_COLUMNS."""
        content = dict(zip(REQUEST_COLUMNS, row))
        return cls.decode(RequestPayload(**content))

    def to_row(self) -> Tuple[Any, ...]:
        """Encode this request as a table row ordered as REQUEST_COLUMNS."""
        payload = self.encode()
        return tuple(getattr(payload, column) for column in REQUEST_COLUMNS)

    def readable_encode(self) -> RequestPayload:
        """Serialize the SkyPilot API request for display purposes.

        This function should be called on the server side to serialize the
        request body into a human-readable format: the entrypoint becomes
        its function name, and the pid, error and return value are left out.

        The returned value is then displayed on the client side in the
        request table.

        `encode` is not used for display, to avoid sending a large amount of
        data to the client, especially since the request table may include
        all the requests.
        """
        assert isinstance(self.request_body,
                          payloads.RequestBody), (self.name, self.request_body)
        user = global_user_state.get_user(self.user_id)
        # NOTE(review): guard against a missing user record so listing
        # requests does not fail; user_name is Optional in RequestPayload.
        user_name = user.name if user is not None else None
        return RequestPayload(
            request_id=self.request_id,
            name=self.name,
            entrypoint=self.entrypoint.__name__,
            request_body=self.request_body.model_dump_json(),
            status=self.status.value,
            return_value=json.dumps(None),
            error=json.dumps(None),
            pid=None,
            created_at=self.created_at,
            schedule_type=self.schedule_type.value,
            user_id=self.user_id,
            user_name=user_name,
            cluster_name=self.cluster_name,
        )

    def encode(self) -> RequestPayload:
        """Serialize the SkyPilot API request for storage.

        The entrypoint and request body are pickled and encoded; the return
        value and error are JSON-encoded.

        Raises:
            TypeError, ValueError: if a field cannot be encoded (logged with
                a stack trace before being re-raised).
        """
        assert isinstance(self.request_body,
                          payloads.RequestBody), (self.name, self.request_body)
        try:
            return RequestPayload(
                request_id=self.request_id,
                name=self.name,
                entrypoint=encoders.pickle_and_encode(self.entrypoint),
                request_body=encoders.pickle_and_encode(self.request_body),
                status=self.status.value,
                return_value=json.dumps(self.return_value),
                error=json.dumps(self.error),
                pid=self.pid,
                created_at=self.created_at,
                schedule_type=self.schedule_type.value,
                user_id=self.user_id,
                cluster_name=self.cluster_name,
            )
        except (TypeError, ValueError) as e:
            # The error is unexpected, so we don't suppress the stack trace.
            logger.error(
                f'Error encoding: {e}\n'
                f'  {self.request_id}\n'
                f'  {self.name}\n'
                f'  {self.request_body}\n'
                f'  {self.return_value}\n'
                f'  {self.created_at}\n',
                exc_info=e)
            raise

    @classmethod
    def decode(cls, payload: RequestPayload) -> 'Request':
        """Deserialize a SkyPilot API request produced by `encode`.

        Raises:
            TypeError, ValueError: if a field cannot be decoded (logged with
                a stack trace before being re-raised).
        """
        try:
            return cls(
                request_id=payload.request_id,
                name=payload.name,
                entrypoint=decoders.decode_and_unpickle(payload.entrypoint),
                request_body=decoders.decode_and_unpickle(payload.request_body),
                status=RequestStatus(payload.status),
                return_value=json.loads(payload.return_value),
                error=json.loads(payload.error),
                pid=payload.pid,
                created_at=payload.created_at,
                schedule_type=ScheduleType(payload.schedule_type),
                user_id=payload.user_id,
                cluster_name=payload.cluster_name,
            )
        except (TypeError, ValueError) as e:
            # The error is unexpected, so we don't suppress the stack trace.
            logger.error(
                f'Error decoding: {e}\n'
                f'  {payload.request_id}\n'
                f'  {payload.name}\n'
                f'  {payload.entrypoint}\n'
                f'  {payload.request_body}\n'
                f'  {payload.created_at}\n',
                exc_info=e)
            raise
|
277
|
+
|
278
|
+
|
279
|
+
def kill_cluster_requests(cluster_name: str,
                          exclude_request_name: str) -> List[str]:
    """Kill all pending and running requests for a cluster.

    Args:
        cluster_name: the name of the cluster.
        exclude_request_name: exclude requests with this name. This is to
            prevent killing the caller request.

    Returns:
        The IDs of the requests that were cancelled (as reported by
        kill_requests).
    """
    request_ids = [
        request_task.request_id for request_task in get_request_tasks(
            cluster_names=[cluster_name],
            status=[RequestStatus.PENDING, RequestStatus.RUNNING],
            exclude_request_names=[exclude_request_name])
    ]
    # Always a list (possibly empty), never None: passing None would make
    # kill_requests cancel every request instead.
    return kill_requests(request_ids)
|
294
|
+
|
295
|
+
|
296
|
+
def refresh_cluster_status_event():
    """Force-refresh the status of all clusters forever, at a fixed interval.

    Registered as a background daemon in INTERNAL_REQUEST_DAEMONS; this
    function never returns.
    """
    # pylint: disable=import-outside-toplevel
    from sky import core

    # Set the DISABLE_LOGGING option for this process so the periodic
    # refresh does not send the usage message over and over again.
    os.environ[env_options.Options.DISABLE_LOGGING.env_key] = '1'

    while True:
        logger.info('=== Refreshing cluster status ===')
        # Holding each cluster's lock while it is refreshed here is fine:
        # other operations simply wait for the lock and then reuse the
        # freshly refreshed status instead of refreshing again.
        core.status(refresh=common.StatusRefreshMode.FORCE, all_users=True)
        interval = server_constants.CLUSTER_REFRESH_DAEMON_INTERVAL_SECONDS
        logger.info(f'Status refreshed. Sleeping {interval} seconds '
                    'for the next refresh...\n')
        time.sleep(interval)
|
316
|
+
|
317
|
+
|
318
|
+
@dataclasses.dataclass
class InternalRequestDaemon:
    """A background task the API server runs as an internal request.

    Instances are listed in INTERNAL_REQUEST_DAEMONS; the code that actually
    schedules them is not in this module.
    """
    # Request ID used for the daemon. kill_requests never cancels a request
    # whose ID belongs to one of the registered daemons.
    id: str
    # Request name recorded for the daemon (e.g. 'status').
    name: str
    # The function the daemon runs; it takes no arguments.
    event_fn: Callable[[], None]
|
323
|
+
|
324
|
+
|
325
|
+
# Daemons that the API server runs in the background as internal requests.
INTERNAL_REQUEST_DAEMONS = [
    # The status refresh daemon may move autostopped/autodowned clusters to
    # their updated status on its own, so users will not see the hint that
    # a cluster was stopped or torn down when they later run `sky status -r`.
    InternalRequestDaemon(
        id='skypilot-status-refresh-daemon',
        name='status',
        event_fn=refresh_cluster_status_event,
    ),
]
|
334
|
+
|
335
|
+
|
336
|
+
def kill_requests(request_ids: Optional[List[str]] = None,
                  user_id: Optional[str] = None) -> List[str]:
    """Kill SkyPilot API requests and set their status to cancelled.

    Requests that are unknown, already finished, or belong to an internal
    daemon (INTERNAL_REQUEST_DAEMONS) are skipped.

    Args:
        request_ids: The request IDs to kill. Each ID is looked up by prefix
            (see update_request). If None, all pending and running requests
            of `user_id` are killed, except 'sky.api_cancel' requests.
        user_id: The user ID to kill requests for when `request_ids` is None.
            If None, requests of all users are killed.

    Returns:
        A list of request IDs that were cancelled.
    """
    if request_ids is None:
        request_ids = [
            request_task.request_id for request_task in get_request_tasks(
                user_id=user_id,
                status=[RequestStatus.RUNNING, RequestStatus.PENDING],
                # Avoid cancelling the cancel request itself.
                exclude_request_names=['sky.api_cancel'])
        ]
    # Internal daemon requests are identified by the IDs of the registered
    # daemons and must never be cancelled. Built once, outside the loop.
    internal_request_ids = {daemon.id for daemon in INTERNAL_REQUEST_DAEMONS}
    cancelled_request_ids = []
    for request_id in request_ids:
        with update_request(request_id) as request_record:
            if request_record is None:
                logger.debug(f'No request ID {request_id}')
                continue
            if request_record.request_id in internal_request_ids:
                continue
            if request_record.status > RequestStatus.RUNNING:
                logger.debug(f'Request {request_id} already finished')
                continue
            if request_record.pid is not None:
                logger.debug(f'Killing request process {request_record.pid}')
                # Use SIGTERM instead of SIGKILL:
                # - The executor can handle SIGTERM gracefully
                # - After SIGTERM, the executor can reuse the request process
                #   for other requests, avoiding the overhead of forking a new
                #   process for each request.
                try:
                    os.kill(request_record.pid, signal.SIGTERM)
                except ProcessLookupError:
                    # The worker process is already gone, so there is
                    # nothing to signal. Still mark the request cancelled
                    # rather than aborting the remaining cancellations.
                    logger.debug(f'Request process {request_record.pid} '
                                 'no longer exists')
            request_record.status = RequestStatus.CANCELLED
            cancelled_request_ids.append(request_id)
    return cancelled_request_ids
|
382
|
+
|
383
|
+
|
384
|
+
# Absolute location of the requests database; its directory is created at
# import time so the database can be opened later.
_DB_PATH = os.path.expanduser(server_constants.API_SERVER_REQUEST_DB_PATH)
_DB_DIR = pathlib.Path(_DB_PATH).parents[0]
_DB_DIR.mkdir(parents=True, exist_ok=True)
|
386
|
+
|
387
|
+
|
388
|
+
def create_table(cursor, conn):
    """Prepare the requests database.

    Used as the table-creation callback of ``db_utils.SQLiteConn`` (see
    init_db). Enables WAL journaling where supported and creates the
    requests table if it does not exist yet.

    Args:
        cursor: a cursor on the database connection.
        conn: the database connection; unused.
    """
    del conn  # Unused.

    # WAL mode avoids locking issues (see issue #1441 and PR #1509), but it
    # is not enabled on WSL because of a known WSL issue:
    # https://github.com/microsoft/WSL/issues/2395
    # TODO(romilb): Without WAL, WSL may still hit the "database is locked"
    # problem from issue #1441.
    use_wal = not common_utils.is_wsl()
    if use_wal:
        try:
            cursor.execute('PRAGMA journal_mode=WAL')
        except sqlite3.OperationalError as err:
            # A locked database is tolerated: WAL mode is not critical and
            # is likely being enabled by another process. Any other
            # operational error is re-raised.
            if 'database is locked' not in str(err):
                raise

    # Table for Requests
    cursor.execute(f"""\
        CREATE TABLE IF NOT EXISTS {REQUEST_TABLE} (
        request_id TEXT PRIMARY KEY,
        name TEXT,
        entrypoint TEXT,
        request_body TEXT,
        status TEXT,
        created_at REAL,
        return_value TEXT,
        error BLOB,
        pid INTEGER,
        {COL_CLUSTER_NAME} TEXT,
        schedule_type TEXT,
        {COL_USER_ID} TEXT)""")
|
419
|
+
|
420
|
+
|
421
|
+
# Lazily-created connection to the requests database (see init_db).
_DB = None


def _ensure_db_initialized() -> None:
    """Open the module-level requests database connection on first use."""
    global _DB
    if _DB is None:
        _DB = db_utils.SQLiteConn(_DB_PATH, create_table)


def init_db(func):
    """Decorator that opens the requests database before `func` runs.

    The connection is created only once, on the first call of any decorated
    function, and is then reused.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        _ensure_db_initialized()
        return func(*args, **kwargs)

    return wrapper
|
435
|
+
|
436
|
+
|
437
|
+
def reset_db_and_logs():
    """Delete the requests database and the request/client log directories.

    Removes the SQLite database file at ``_DB_PATH`` together with its WAL
    side files (``-wal`` and ``-shm``), the request log directory
    (``REQUEST_LOG_PATH_PREFIX``) and the API server client directory.
    Directory removal ignores errors (``ignore_errors=True``).

    NOTE(review): an already-open module-level ``_DB`` connection is not
    reset here. This is presumably called before any database access
    (e.g. at server start-up); confirm with callers.
    """
    common_utils.remove_file_if_exists(_DB_PATH)
    # The database is opened in WAL mode where supported (see create_table).
    # Remove the WAL side files too, so a stale WAL is never paired with the
    # database file that gets created afterwards.
    for suffix in ('-wal', '-shm'):
        common_utils.remove_file_if_exists(_DB_PATH + suffix)
    shutil.rmtree(pathlib.Path(REQUEST_LOG_PATH_PREFIX).expanduser(),
                  ignore_errors=True)
    shutil.rmtree(server_common.API_SERVER_CLIENT_DIR.expanduser(),
                  ignore_errors=True)
|
444
|
+
|
445
|
+
|
446
|
+
def request_lock_path(request_id: str) -> str:
    """Return the path of the file lock guarding `request_id`.

    Lock files are hidden files named ``.<request_id>.lock`` placed in the
    request log directory, which is created if it does not exist.
    """
    lock_dir = os.path.expanduser(REQUEST_LOG_PATH_PREFIX)
    os.makedirs(lock_dir, exist_ok=True)
    lock_file_name = f'.{request_id}.lock'
    return os.path.join(lock_dir, lock_file_name)
|
450
|
+
|
451
|
+
|
452
|
+
@contextlib.contextmanager
@init_db
def update_request(request_id: str):
    """Context manager to read, modify and persist a SkyPilot API request.

    Yields the request found by _get_request_no_lock (prefix match on
    `request_id`), or None if there is none. Changes made to the yielded
    object inside the ``with`` block are written back to the database when
    the block exits without an exception; if the block raises, nothing is
    written back.

    NOTE(review): unlike get_request / create_if_not_exists, this does not
    take the per-request file lock (request_lock_path), so concurrent
    updates of the same request may race; confirm callers serialize them.
    """
    request = _get_request_no_lock(request_id)
    yield request
    # Only reached when the `with` body exits normally.
    if request is not None:
        _add_or_update_request_no_lock(request)
|
460
|
+
|
461
|
+
|
462
|
+
def _get_request_no_lock(request_id: str) -> Optional[Request]:
    """Look up a request by ID prefix without taking its file lock.

    `request_id` is matched as a prefix (SQL ``LIKE '<id>%'``), so a
    shortened ID also works; if several requests match, whichever row
    SQLite returns first is used. Returns None when nothing matches.
    The database must already be initialized (see init_db).
    """
    assert _DB is not None
    selected_columns = ', '.join(REQUEST_COLUMNS)
    query = (f'SELECT {selected_columns} FROM {REQUEST_TABLE} '
             'WHERE request_id LIKE ?')
    with _DB.conn:
        cursor = _DB.conn.cursor()
        cursor.execute(query, (request_id + '%',))
        found = cursor.fetchone()
    if found is None:
        return None
    return Request.from_row(found)
|
475
|
+
|
476
|
+
|
477
|
+
@init_db
def get_latest_request_id() -> Optional[str]:
    """Return the ID of the most recently created request.

    Returns None when the requests table is empty.
    """
    assert _DB is not None
    query = (f'SELECT request_id FROM {REQUEST_TABLE} '
             'ORDER BY created_at DESC LIMIT 1')
    with _DB.conn:
        cursor = _DB.conn.cursor()
        cursor.execute(query)
        latest = cursor.fetchone()
    if not latest:
        return None
    return latest[0]
|
487
|
+
|
488
|
+
|
489
|
+
@init_db
def get_request(request_id: str) -> Optional[Request]:
    """Fetch a SkyPilot API request while holding its file lock.

    Uses the same prefix lookup as _get_request_no_lock and returns None
    when no request matches.
    """
    lock = filelock.FileLock(request_lock_path(request_id))
    with lock:
        found = _get_request_no_lock(request_id)
    return found
|
494
|
+
|
495
|
+
|
496
|
+
@init_db
def create_if_not_exists(request: Request) -> bool:
    """Store `request` unless a matching request already exists.

    The existence check and the insert both happen under the request's file
    lock. The check uses the prefix lookup of _get_request_no_lock, so an
    existing request whose ID starts with this ID also counts as existing.

    Returns:
        True if the request was inserted, False if one already existed.
    """
    with filelock.FileLock(request_lock_path(request.request_id)):
        already_exists = _get_request_no_lock(request.request_id) is not None
        if not already_exists:
            _add_or_update_request_no_lock(request)
    return not already_exists
|
504
|
+
|
505
|
+
|
506
|
+
@init_db
def get_request_tasks(
    status: Optional[List[RequestStatus]] = None,
    cluster_names: Optional[List[str]] = None,
    exclude_request_names: Optional[List[str]] = None,
    user_id: Optional[str] = None,
) -> List[Request]:
    """Get a list of requests that match the given filters.

    All given filters are combined with AND; a filter left as None is not
    applied. Results are ordered by creation time, newest first.

    Args:
        status: a list of statuses of the requests to filter on.
        cluster_names: a list of cluster names to filter requests on.
        exclude_request_names: a list of request names to exclude from results.
        user_id: the user ID to filter requests on.
            If None, all users are included.

    Returns:
        The matching requests.
    """
    filters: List[str] = []
    filter_params: List[Any] = []

    def _add_in_filter(column: str,
                       values: List[Any],
                       negate: bool = False) -> None:
        # Bind every value as a query parameter rather than splicing
        # `repr(value)` into the SQL text, which mis-quotes (and allows SQL
        # injection through) values containing quote characters.
        placeholders = ','.join(['?'] * len(values))
        op = 'NOT IN' if negate else 'IN'
        filters.append(f'{column} {op} ({placeholders})')
        filter_params.extend(values)

    # Filters and their parameters are appended in the same order, so the
    # positional '?' placeholders line up with `filter_params`.
    if status is not None:
        _add_in_filter('status', [s.value for s in status])
    if exclude_request_names is not None:
        _add_in_filter('name', list(exclude_request_names), negate=True)
    if cluster_names is not None:
        _add_in_filter(COL_CLUSTER_NAME, list(cluster_names))
    if user_id is not None:
        filters.append(f'{COL_USER_ID} = ?')
        filter_params.append(user_id)

    filter_str = ' AND '.join(filters)
    if filter_str:
        filter_str = f' WHERE {filter_str}'
    columns_str = ', '.join(REQUEST_COLUMNS)
    assert _DB is not None
    with _DB.conn:
        cursor = _DB.conn.cursor()
        cursor.execute(
            f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str} '
            'ORDER BY created_at DESC', filter_params)
        rows = cursor.fetchall()
    return [Request.from_row(row) for row in rows]
|
555
|
+
|
556
|
+
|
557
|
+
def _add_or_update_request_no_lock(request: Request):
    """Insert `request` into the requests table, replacing any stored row
    with the same request_id (the table's primary key).

    No file lock is taken; callers are responsible for any locking.
    """
    values = request.to_row()
    columns = ', '.join(REQUEST_COLUMNS)
    placeholders = ', '.join('?' for _ in values)
    statement = (f'INSERT OR REPLACE INTO {REQUEST_TABLE} ({columns}) '
                 f'VALUES ({placeholders})')
    assert _DB is not None
    with _DB.conn:
        _DB.conn.cursor().execute(statement, values)
|
File without changes
|