skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +64 -32
- sky/adaptors/aws.py +23 -6
- sky/adaptors/azure.py +432 -15
- sky/adaptors/cloudflare.py +5 -5
- sky/adaptors/common.py +19 -9
- sky/adaptors/do.py +20 -0
- sky/adaptors/gcp.py +3 -2
- sky/adaptors/kubernetes.py +122 -88
- sky/adaptors/nebius.py +100 -0
- sky/adaptors/oci.py +39 -1
- sky/adaptors/vast.py +29 -0
- sky/admin_policy.py +101 -0
- sky/authentication.py +117 -98
- sky/backends/backend.py +52 -20
- sky/backends/backend_utils.py +669 -557
- sky/backends/cloud_vm_ray_backend.py +1099 -808
- sky/backends/local_docker_backend.py +14 -8
- sky/backends/wheel_utils.py +38 -20
- sky/benchmark/benchmark_utils.py +22 -23
- sky/check.py +76 -27
- sky/cli.py +1586 -1139
- sky/client/__init__.py +1 -0
- sky/client/cli.py +5683 -0
- sky/client/common.py +345 -0
- sky/client/sdk.py +1765 -0
- sky/cloud_stores.py +283 -19
- sky/clouds/__init__.py +7 -2
- sky/clouds/aws.py +303 -112
- sky/clouds/azure.py +185 -179
- sky/clouds/cloud.py +115 -37
- sky/clouds/cudo.py +29 -22
- sky/clouds/do.py +313 -0
- sky/clouds/fluidstack.py +44 -54
- sky/clouds/gcp.py +206 -65
- sky/clouds/ibm.py +26 -21
- sky/clouds/kubernetes.py +345 -91
- sky/clouds/lambda_cloud.py +40 -29
- sky/clouds/nebius.py +297 -0
- sky/clouds/oci.py +129 -90
- sky/clouds/paperspace.py +22 -18
- sky/clouds/runpod.py +53 -34
- sky/clouds/scp.py +28 -24
- sky/clouds/service_catalog/__init__.py +19 -13
- sky/clouds/service_catalog/aws_catalog.py +29 -12
- sky/clouds/service_catalog/azure_catalog.py +33 -6
- sky/clouds/service_catalog/common.py +95 -75
- sky/clouds/service_catalog/constants.py +3 -3
- sky/clouds/service_catalog/cudo_catalog.py +13 -3
- sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
- sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
- sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
- sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
- sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
- sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
- sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
- sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
- sky/clouds/service_catalog/do_catalog.py +111 -0
- sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
- sky/clouds/service_catalog/gcp_catalog.py +16 -2
- sky/clouds/service_catalog/ibm_catalog.py +2 -2
- sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
- sky/clouds/service_catalog/lambda_catalog.py +8 -3
- sky/clouds/service_catalog/nebius_catalog.py +116 -0
- sky/clouds/service_catalog/oci_catalog.py +31 -4
- sky/clouds/service_catalog/paperspace_catalog.py +2 -2
- sky/clouds/service_catalog/runpod_catalog.py +2 -2
- sky/clouds/service_catalog/scp_catalog.py +2 -2
- sky/clouds/service_catalog/vast_catalog.py +104 -0
- sky/clouds/service_catalog/vsphere_catalog.py +2 -2
- sky/clouds/utils/aws_utils.py +65 -0
- sky/clouds/utils/azure_utils.py +91 -0
- sky/clouds/utils/gcp_utils.py +5 -9
- sky/clouds/utils/oci_utils.py +47 -5
- sky/clouds/utils/scp_utils.py +4 -3
- sky/clouds/vast.py +280 -0
- sky/clouds/vsphere.py +22 -18
- sky/core.py +361 -107
- sky/dag.py +41 -28
- sky/data/data_transfer.py +37 -0
- sky/data/data_utils.py +211 -32
- sky/data/mounting_utils.py +182 -30
- sky/data/storage.py +2118 -270
- sky/data/storage_utils.py +126 -5
- sky/exceptions.py +179 -8
- sky/execution.py +158 -85
- sky/global_user_state.py +150 -34
- sky/jobs/__init__.py +12 -10
- sky/jobs/client/__init__.py +0 -0
- sky/jobs/client/sdk.py +302 -0
- sky/jobs/constants.py +49 -11
- sky/jobs/controller.py +161 -99
- sky/jobs/dashboard/dashboard.py +171 -25
- sky/jobs/dashboard/templates/index.html +572 -60
- sky/jobs/recovery_strategy.py +157 -156
- sky/jobs/scheduler.py +307 -0
- sky/jobs/server/__init__.py +1 -0
- sky/jobs/server/core.py +598 -0
- sky/jobs/server/dashboard_utils.py +69 -0
- sky/jobs/server/server.py +190 -0
- sky/jobs/state.py +627 -122
- sky/jobs/utils.py +615 -206
- sky/models.py +27 -0
- sky/optimizer.py +142 -83
- sky/provision/__init__.py +20 -5
- sky/provision/aws/config.py +124 -42
- sky/provision/aws/instance.py +130 -53
- sky/provision/azure/__init__.py +7 -0
- sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
- sky/provision/azure/config.py +220 -0
- sky/provision/azure/instance.py +1012 -37
- sky/provision/common.py +31 -3
- sky/provision/constants.py +25 -0
- sky/provision/cudo/__init__.py +2 -1
- sky/provision/cudo/cudo_utils.py +112 -0
- sky/provision/cudo/cudo_wrapper.py +37 -16
- sky/provision/cudo/instance.py +28 -12
- sky/provision/do/__init__.py +11 -0
- sky/provision/do/config.py +14 -0
- sky/provision/do/constants.py +10 -0
- sky/provision/do/instance.py +287 -0
- sky/provision/do/utils.py +301 -0
- sky/provision/docker_utils.py +82 -46
- sky/provision/fluidstack/fluidstack_utils.py +57 -125
- sky/provision/fluidstack/instance.py +15 -43
- sky/provision/gcp/config.py +19 -9
- sky/provision/gcp/constants.py +7 -1
- sky/provision/gcp/instance.py +55 -34
- sky/provision/gcp/instance_utils.py +339 -80
- sky/provision/gcp/mig_utils.py +210 -0
- sky/provision/instance_setup.py +172 -133
- sky/provision/kubernetes/__init__.py +1 -0
- sky/provision/kubernetes/config.py +104 -90
- sky/provision/kubernetes/constants.py +8 -0
- sky/provision/kubernetes/instance.py +680 -325
- sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
- sky/provision/kubernetes/network.py +54 -20
- sky/provision/kubernetes/network_utils.py +70 -21
- sky/provision/kubernetes/utils.py +1370 -251
- sky/provision/lambda_cloud/__init__.py +11 -0
- sky/provision/lambda_cloud/config.py +10 -0
- sky/provision/lambda_cloud/instance.py +265 -0
- sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
- sky/provision/logging.py +1 -1
- sky/provision/nebius/__init__.py +11 -0
- sky/provision/nebius/config.py +11 -0
- sky/provision/nebius/instance.py +285 -0
- sky/provision/nebius/utils.py +318 -0
- sky/provision/oci/__init__.py +15 -0
- sky/provision/oci/config.py +51 -0
- sky/provision/oci/instance.py +436 -0
- sky/provision/oci/query_utils.py +681 -0
- sky/provision/paperspace/constants.py +6 -0
- sky/provision/paperspace/instance.py +4 -3
- sky/provision/paperspace/utils.py +2 -0
- sky/provision/provisioner.py +207 -130
- sky/provision/runpod/__init__.py +1 -0
- sky/provision/runpod/api/__init__.py +3 -0
- sky/provision/runpod/api/commands.py +119 -0
- sky/provision/runpod/api/pods.py +142 -0
- sky/provision/runpod/instance.py +64 -8
- sky/provision/runpod/utils.py +239 -23
- sky/provision/vast/__init__.py +10 -0
- sky/provision/vast/config.py +11 -0
- sky/provision/vast/instance.py +247 -0
- sky/provision/vast/utils.py +162 -0
- sky/provision/vsphere/common/vim_utils.py +1 -1
- sky/provision/vsphere/instance.py +8 -18
- sky/provision/vsphere/vsphere_utils.py +1 -1
- sky/resources.py +247 -102
- sky/serve/__init__.py +9 -9
- sky/serve/autoscalers.py +361 -299
- sky/serve/client/__init__.py +0 -0
- sky/serve/client/sdk.py +366 -0
- sky/serve/constants.py +12 -3
- sky/serve/controller.py +106 -36
- sky/serve/load_balancer.py +63 -12
- sky/serve/load_balancing_policies.py +84 -2
- sky/serve/replica_managers.py +42 -34
- sky/serve/serve_state.py +62 -32
- sky/serve/serve_utils.py +271 -160
- sky/serve/server/__init__.py +0 -0
- sky/serve/{core.py → server/core.py} +271 -90
- sky/serve/server/server.py +112 -0
- sky/serve/service.py +52 -16
- sky/serve/service_spec.py +95 -32
- sky/server/__init__.py +1 -0
- sky/server/common.py +430 -0
- sky/server/constants.py +21 -0
- sky/server/html/log.html +174 -0
- sky/server/requests/__init__.py +0 -0
- sky/server/requests/executor.py +472 -0
- sky/server/requests/payloads.py +487 -0
- sky/server/requests/queues/__init__.py +0 -0
- sky/server/requests/queues/mp_queue.py +76 -0
- sky/server/requests/requests.py +567 -0
- sky/server/requests/serializers/__init__.py +0 -0
- sky/server/requests/serializers/decoders.py +192 -0
- sky/server/requests/serializers/encoders.py +166 -0
- sky/server/server.py +1106 -0
- sky/server/stream_utils.py +141 -0
- sky/setup_files/MANIFEST.in +2 -5
- sky/setup_files/dependencies.py +159 -0
- sky/setup_files/setup.py +14 -125
- sky/sky_logging.py +59 -14
- sky/skylet/autostop_lib.py +2 -2
- sky/skylet/constants.py +183 -50
- sky/skylet/events.py +22 -10
- sky/skylet/job_lib.py +403 -258
- sky/skylet/log_lib.py +111 -71
- sky/skylet/log_lib.pyi +6 -0
- sky/skylet/providers/command_runner.py +6 -8
- sky/skylet/providers/ibm/node_provider.py +2 -2
- sky/skylet/providers/scp/config.py +11 -3
- sky/skylet/providers/scp/node_provider.py +8 -8
- sky/skylet/skylet.py +3 -1
- sky/skylet/subprocess_daemon.py +69 -17
- sky/skypilot_config.py +119 -57
- sky/task.py +205 -64
- sky/templates/aws-ray.yml.j2 +37 -7
- sky/templates/azure-ray.yml.j2 +27 -82
- sky/templates/cudo-ray.yml.j2 +7 -3
- sky/templates/do-ray.yml.j2 +98 -0
- sky/templates/fluidstack-ray.yml.j2 +7 -4
- sky/templates/gcp-ray.yml.j2 +26 -6
- sky/templates/ibm-ray.yml.j2 +3 -2
- sky/templates/jobs-controller.yaml.j2 +46 -11
- sky/templates/kubernetes-ingress.yml.j2 +7 -0
- sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
- sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
- sky/templates/kubernetes-ray.yml.j2 +292 -25
- sky/templates/lambda-ray.yml.j2 +30 -40
- sky/templates/nebius-ray.yml.j2 +79 -0
- sky/templates/oci-ray.yml.j2 +18 -57
- sky/templates/paperspace-ray.yml.j2 +10 -6
- sky/templates/runpod-ray.yml.j2 +26 -4
- sky/templates/scp-ray.yml.j2 +3 -2
- sky/templates/sky-serve-controller.yaml.j2 +12 -1
- sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
- sky/templates/vast-ray.yml.j2 +70 -0
- sky/templates/vsphere-ray.yml.j2 +8 -3
- sky/templates/websocket_proxy.py +64 -0
- sky/usage/constants.py +10 -1
- sky/usage/usage_lib.py +130 -37
- sky/utils/accelerator_registry.py +35 -51
- sky/utils/admin_policy_utils.py +147 -0
- sky/utils/annotations.py +51 -0
- sky/utils/cli_utils/status_utils.py +81 -23
- sky/utils/cluster_utils.py +356 -0
- sky/utils/command_runner.py +452 -89
- sky/utils/command_runner.pyi +77 -3
- sky/utils/common.py +54 -0
- sky/utils/common_utils.py +319 -108
- sky/utils/config_utils.py +204 -0
- sky/utils/control_master_utils.py +48 -0
- sky/utils/controller_utils.py +548 -266
- sky/utils/dag_utils.py +93 -32
- sky/utils/db_utils.py +18 -4
- sky/utils/env_options.py +29 -7
- sky/utils/kubernetes/create_cluster.sh +8 -60
- sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
- sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
- sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
- sky/utils/kubernetes/gpu_labeler.py +4 -4
- sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
- sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
- sky/utils/kubernetes/rsync_helper.sh +24 -0
- sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
- sky/utils/log_utils.py +240 -33
- sky/utils/message_utils.py +81 -0
- sky/utils/registry.py +127 -0
- sky/utils/resources_utils.py +94 -22
- sky/utils/rich_utils.py +247 -18
- sky/utils/schemas.py +284 -64
- sky/{status_lib.py → utils/status_lib.py} +12 -7
- sky/utils/subprocess_utils.py +212 -46
- sky/utils/timeline.py +12 -7
- sky/utils/ux_utils.py +168 -15
- skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
- skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
- sky/clouds/cloud_registry.py +0 -31
- sky/jobs/core.py +0 -330
- sky/skylet/providers/azure/__init__.py +0 -2
- sky/skylet/providers/azure/azure-vm-template.json +0 -301
- sky/skylet/providers/azure/config.py +0 -170
- sky/skylet/providers/azure/node_provider.py +0 -466
- sky/skylet/providers/lambda_cloud/__init__.py +0 -2
- sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
- sky/skylet/providers/oci/__init__.py +0 -2
- sky/skylet/providers/oci/node_provider.py +0 -488
- sky/skylet/providers/oci/query_helper.py +0 -383
- sky/skylet/providers/oci/utils.py +0 -21
- sky/utils/cluster_yaml_utils.py +0 -24
- sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
- skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
- skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/jobs/server/core.py
ADDED
@@ -0,0 +1,598 @@
|
|
1
|
+
"""SDK functions for managed jobs."""
|
2
|
+
import os
|
3
|
+
import signal
|
4
|
+
import subprocess
|
5
|
+
import tempfile
|
6
|
+
import time
|
7
|
+
import typing
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
9
|
+
import uuid
|
10
|
+
|
11
|
+
import colorama
|
12
|
+
|
13
|
+
from sky import backends
|
14
|
+
from sky import core
|
15
|
+
from sky import exceptions
|
16
|
+
from sky import execution
|
17
|
+
from sky import provision as provision_lib
|
18
|
+
from sky import sky_logging
|
19
|
+
from sky import task as task_lib
|
20
|
+
from sky.backends import backend_utils
|
21
|
+
from sky.clouds.service_catalog import common as service_catalog_common
|
22
|
+
from sky.data import storage as storage_lib
|
23
|
+
from sky.jobs import constants as managed_job_constants
|
24
|
+
from sky.jobs import utils as managed_job_utils
|
25
|
+
from sky.provision import common as provision_common
|
26
|
+
from sky.skylet import constants as skylet_constants
|
27
|
+
from sky.usage import usage_lib
|
28
|
+
from sky.utils import admin_policy_utils
|
29
|
+
from sky.utils import common
|
30
|
+
from sky.utils import common_utils
|
31
|
+
from sky.utils import controller_utils
|
32
|
+
from sky.utils import dag_utils
|
33
|
+
from sky.utils import rich_utils
|
34
|
+
from sky.utils import status_lib
|
35
|
+
from sky.utils import subprocess_utils
|
36
|
+
from sky.utils import timeline
|
37
|
+
from sky.utils import ux_utils
|
38
|
+
|
39
|
+
if typing.TYPE_CHECKING:
|
40
|
+
import sky
|
41
|
+
from sky.backends import cloud_vm_ray_backend
|
42
|
+
|
43
|
+
logger = sky_logging.init_logger(__name__)
|
44
|
+
|
45
|
+
|
46
|
+
@timeline.event
@usage_lib.entrypoint
def launch(
    task: Union['sky.Task', 'sky.Dag'],
    name: Optional[str] = None,
    stream_logs: bool = True,
) -> Tuple[Optional[int], Optional[backends.ResourceHandle]]:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Launches a managed job.

    Please refer to sky.cli.job_launch for documentation.

    Args:
        task: sky.Task, or sky.Dag (experimental; 1-task only) to launch as a
          managed job.
        name: Name of the managed job. In this function it is only used to
          name the locally generated controller YAML; when None, the
          (possibly inferred) DAG name is used for that file instead.
        stream_logs: Whether to stream logs while launching the jobs
          controller task.

    Raises:
        ValueError: cluster does not exist. Or, the entrypoint is not a valid
          chain dag, or two tasks in the DAG share the same name.
        sky.exceptions.NotSupportedError: the feature is not supported, e.g.
          cloud-based storage mounts are requested but no cloud storage is
          enabled.

    Returns:
      job_id: Optional[int]; the job ID of the submitted job. None if the
        backend is not CloudVmRayBackend, or no job is submitted to
        the cluster.
      handle: Optional[backends.ResourceHandle]; handle to the controller VM.
        None if dryrun.
    """
    entrypoint = task
    # Short random suffix that keeps the per-launch YAML paths distinct.
    dag_uuid = uuid.uuid4().hex[:4]
    dag = dag_utils.convert_entrypoint_to_dag(entrypoint)
    # Always apply the policy again here, even though it might have been applied
    # in the CLI. This is to ensure that we apply the policy to the final DAG
    # and get the mutated config.
    dag, mutated_user_config = admin_policy_utils.apply(
        dag, use_mutated_config_in_current_request=False)
    if not dag.is_chain():
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Only single-task or chain DAG is '
                             f'allowed for job_launch. Dag: {dag}')
    dag.validate()
    dag_utils.maybe_infer_and_fill_dag_and_task_names(dag)

    task_names = set()
    for task_ in dag.tasks:
        if task_.name in task_names:
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    f'Task name {task_.name!r} is duplicated in the DAG. '
                    'Either change task names to be unique, or specify the DAG '
                    'name only and comment out the task names (so that they '
                    'will be auto-generated) .')
        task_names.add(task_.name)

    dag_utils.fill_default_config_in_dag_for_job_launch(dag)

    with rich_utils.safe_status(
            ux_utils.spinner_message('Initializing managed job')):

        local_to_controller_file_mounts = {}

        if storage_lib.get_cached_enabled_storage_clouds_or_refresh():
            for task_ in dag.tasks:
                controller_utils.maybe_translate_local_file_mounts_and_sync_up(
                    task_, task_type='jobs')

        else:
            # We do not have any cloud storage available, so fall back to
            # two-hop file_mount uploading.
            # Note: we can't easily hack sync_storage_mounts() to upload
            # directly to the controller, because the controller may not
            # even be up yet.
            for task_ in dag.tasks:
                if task_.storage_mounts:
                    # Technically, we could convert COPY storage_mounts that
                    # have a local source and do not specify `store`, but we
                    # will not do that for now. Only plain file_mounts are
                    # supported.
                    raise exceptions.NotSupportedError(
                        'Cloud-based file_mounts are specified, but no cloud '
                        'storage is available. Please specify local '
                        'file_mounts only.')

                # Merge file mounts from all tasks.
                local_to_controller_file_mounts.update(
                    controller_utils.translate_local_file_mounts_to_two_hop(
                        task_))

    with tempfile.NamedTemporaryFile(prefix=f'managed-dag-{dag.name}-',
                                     mode='w') as f:
        # The user DAG YAML must outlive the launch below because its path
        # (f.name) is passed to the controller template as `user_yaml_path`.
        dag_utils.dump_chain_dag_to_yaml(dag, f.name)
        controller = controller_utils.Controllers.JOBS_CONTROLLER
        controller_name = controller.value.cluster_name
        prefix = managed_job_constants.JOBS_TASK_YAML_PREFIX
        remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml'
        remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml'
        remote_env_file_path = f'{prefix}/{dag.name}-{dag_uuid}.env'
        # Flatten the resources of every task (linear, unlike sum(..., [])).
        all_task_resources = [
            resources for t in dag.tasks for resources in t.resources
        ]
        controller_resources = controller_utils.get_controller_resources(
            controller=controller_utils.Controllers.JOBS_CONTROLLER,
            task_resources=all_task_resources)

        vars_to_fill = {
            'remote_user_yaml_path': remote_user_yaml_path,
            'user_yaml_path': f.name,
            'local_to_controller_file_mounts': local_to_controller_file_mounts,
            'jobs_controller': controller_name,
            # Note: actual cluster name will be <task.name>-<managed job ID>
            'dag_name': dag.name,
            'remote_user_config_path': remote_user_config_path,
            'remote_env_file_path': remote_env_file_path,
            'modified_catalogs':
                service_catalog_common.get_modified_catalog_file_mounts(),
            'dashboard_setup_cmd': managed_job_constants.DASHBOARD_SETUP_CMD,
            **controller_utils.shared_controller_vars_to_fill(
                controller_utils.Controllers.JOBS_CONTROLLER,
                remote_user_config_path=remote_user_config_path,
                local_user_config=mutated_user_config,
            ),
        }

        # `name` is optional; without a fallback the generated file would be
        # literally named 'None-<uuid>.yaml'.
        controller_yaml_name = name if name is not None else dag.name
        yaml_path = os.path.join(
            managed_job_constants.JOBS_CONTROLLER_YAML_PREFIX,
            f'{controller_yaml_name}-{dag_uuid}.yaml')
        common_utils.fill_template(
            managed_job_constants.JOBS_CONTROLLER_TEMPLATE,
            vars_to_fill,
            output_path=yaml_path)
        controller_task = task_lib.Task.from_yaml(yaml_path)
        controller_task.set_resources(controller_resources)

        controller_task.managed_job_dag = dag

        sky_logging.print(
            f'{colorama.Fore.YELLOW}'
            f'Launching managed job {dag.name!r} from jobs controller...'
            f'{colorama.Style.RESET_ALL}')

        # Launch with the api server's user hash, so that sky status does not
        # show the owner of the controller as whatever user launched it first.
        with common.with_server_user_hash():
            return execution.launch(task=controller_task,
                                    cluster_name=controller_name,
                                    stream_logs=stream_logs,
                                    idle_minutes_to_autostop=skylet_constants.
                                    CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP,
                                    retry_until_up=True,
                                    fast=True,
                                    _disable_controller_check=True)
|
195
|
+
|
196
|
+
|
197
|
+
def queue_from_kubernetes_pod(
        pod_name: str,
        context: Optional[str] = None,
        skip_finished: bool = False) -> List[Dict[str, Any]]:
    """Fetches the managed jobs table directly from a jobs controller pod.

    Args:
        pod_name (str): Name of the controller pod whose job table is read.
        context (Optional[str]): Kubernetes context to use; None means the
            current context.
        skip_finished (bool): When True, jobs whose tasks have all reached a
            terminal status are left out.

    Returns:
        [
            {
                'job_id': int,
                'job_name': str,
                'resources': str,
                'submitted_at': (float) timestamp of submission,
                'end_at': (float) timestamp of end,
                'duration': (float) duration in seconds,
                'recovery_count': (int) Number of retries,
                'status': (sky.jobs.ManagedJobStatus) of the job,
                'cluster_resources': (str) resources of the cluster,
                'region': (str) region of the cluster,
            }
        ]

    Raises:
        RuntimeError: If there's an error fetching the managed jobs.
    """
    # A synthetic single-instance ClusterInfo is enough to obtain a command
    # runner for the pod; IP addresses are not required for Kubernetes.
    pod_instance = provision_common.InstanceInfo(instance_id=pod_name,
                                                 internal_ip='',
                                                 external_ip='',
                                                 tags={})
    cluster_info = provision_common.ClusterInfo(
        provider_name='kubernetes',
        head_instance_id=pod_name,
        provider_config={'context': context},
        instances={pod_name: [pod_instance]})
    pod_runner = provision_lib.get_command_runners('kubernetes',
                                                   cluster_info)[0]

    table_code = managed_job_utils.ManagedJobCodeGen.get_job_table()
    returncode, payload, stderr = pod_runner.run(table_code,
                                                 require_outputs=True,
                                                 separate_stderr=True,
                                                 stream_logs=False)
    try:
        subprocess_utils.handle_returncode(returncode,
                                           table_code,
                                           'Failed to fetch managed jobs',
                                           payload + stderr,
                                           stream_logs=False)
    except exceptions.CommandError as err:
        raise RuntimeError(str(err)) from err

    all_jobs = managed_job_utils.load_managed_job_queue(payload)
    if not skip_finished:
        return all_jobs

    # Keep every task of a multi-task job as long as at least one of its
    # tasks is still non-terminal.
    active_ids = {
        entry['job_id']
        for entry in all_jobs
        if not entry['status'].is_terminal()
    }
    return [entry for entry in all_jobs if entry['job_id'] in active_ids]
|
271
|
+
|
272
|
+
|
273
|
+
def _maybe_restart_controller(
        refresh: bool, stopped_message: str, spinner_message: str
) -> 'cloud_vm_ray_backend.CloudVmRayResourceHandle':
    """Returns the jobs controller handle, restarting the controller if allowed.

    The controller's accessibility is checked first. If that check raises
    ``exceptions.ClusterNotUpError`` and ``refresh`` is False, the error is
    re-raised. If ``refresh`` is True, the controller is started with
    ``core.start`` and the jobs dashboard setup command is run on its first
    command runner.

    Args:
        refresh: Whether a non-up controller may be restarted. When True, an
            empty stopped message is passed to the accessibility check.
        stopped_message: Message handed to the accessibility check.
        spinner_message: Base text for the status spinner while restarting.
    """
    controller_type = controller_utils.Controllers.JOBS_CONTROLLER
    check_message = '' if refresh else stopped_message

    try:
        handle = backend_utils.is_controller_accessible(
            controller=controller_type, stopped_message=check_message)
    except exceptions.ClusterNotUpError as not_up_error:
        if not refresh:
            raise
        handle, controller_status = None, not_up_error.cluster_status

    if handle is not None:
        return handle

    sky_logging.print(f'{colorama.Fore.YELLOW}'
                      f'Restarting {controller_type.value.name}...'
                      f'{colorama.Style.RESET_ALL}')
    rich_utils.force_update_status(
        ux_utils.spinner_message(f'{spinner_message} - restarting controller'))
    handle = core.start(cluster_name=controller_type.value.cluster_name)

    # Make sure the dashboard is running when the controller is restarted.
    # We should not directly use execution.launch() and have the dashboard cmd
    # in the task setup because since we are using detached_setup, it will
    # become a job on controller which messes up the job IDs (we assume the
    # job ID in controller's job queue is consistent with managed job IDs).
    with rich_utils.safe_status(
            ux_utils.spinner_message('Starting dashboard...')):
        head_runner = handle.get_command_runners()[0]
        dashboard_cmd = (f'export {skylet_constants.USER_ID_ENV_VAR}='
                         f'{common_utils.get_user_hash()!r}; '
                         f'{managed_job_constants.DASHBOARD_SETUP_CMD}')
        head_runner.run(dashboard_cmd, stream_logs=True)
        controller_status = status_lib.ClusterStatus.UP
        rich_utils.force_update_status(
            ux_utils.spinner_message(spinner_message))

    assert handle is not None, (controller_status, refresh)
    return handle
|
320
|
+
|
321
|
+
|
322
|
+
@usage_lib.entrypoint
def queue(refresh: bool,
          skip_finished: bool = False,
          all_users: bool = False) -> List[Dict[str, Any]]:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Gets statuses of managed jobs.

    Please refer to sky.cli.job_queue for documentation.

    Args:
        refresh: Restart the jobs controller if it is not up.
        skip_finished: If True, drop jobs whose tasks are all in a terminal
          status (a partially finished multi-task job is kept in full).
        all_users: If True, return jobs of every user; otherwise only jobs
          whose recorded user hash matches the caller's (jobs without a
          recorded user hash are kept for backwards compatibility).

    Returns:
        [
            {
                'job_id': int,
                'job_name': str,
                'resources': str,
                'submitted_at': (float) timestamp of submission,
                'end_at': (float) timestamp of end,
                'duration': (float) duration in seconds,
                'recovery_count': (int) Number of retries,
                'status': (sky.jobs.ManagedJobStatus) of the job,
                'cluster_resources': (str) resources of the cluster,
                'region': (str) region of the cluster,
            }
        ]
    Raises:
        sky.exceptions.ClusterNotUpError: the jobs controller is not up or
          does not exist.
        RuntimeError: if failed to get the managed jobs with ssh.
    """
    handle = _maybe_restart_controller(refresh,
                                       stopped_message='No in-progress '
                                       'managed jobs.',
                                       spinner_message='Checking '
                                       'managed jobs')
    backend = backend_utils.get_backend_from_handle(handle)
    assert isinstance(backend, backends.CloudVmRayBackend)

    code = managed_job_utils.ManagedJobCodeGen.get_job_table()
    returncode, job_table_payload, stderr = backend.run_on_head(
        handle,
        code,
        require_outputs=True,
        stream_logs=False,
        separate_stderr=True)

    if returncode != 0:
        logger.error(job_table_payload + stderr)
        raise RuntimeError('Failed to fetch managed jobs with returncode: '
                           f'{returncode}')

    jobs = managed_job_utils.load_managed_job_queue(job_table_payload)

    if not all_users:
        # Resolve the caller's hash once, rather than once per job.
        current_user_hash = common_utils.get_user_hash()

        def user_hash_matches_or_missing(job: Dict[str, Any]) -> bool:
            user_hash = job.get('user_hash', None)
            if user_hash is None:
                # For backwards compatibility, we show jobs that do not have a
                # user_hash. TODO(cooperc): Remove before 0.12.0.
                return True
            return user_hash == current_user_hash

        jobs = [job for job in jobs if user_hash_matches_or_missing(job)]

    if skip_finished:
        # Filter out the finished jobs. If a multi-task job is partially
        # finished, we will include all its tasks.
        non_finished_job_ids = {
            job['job_id'] for job in jobs if not job['status'].is_terminal()
        }
        jobs = [job for job in jobs if job['job_id'] in non_finished_job_ids]

    return jobs
|
396
|
+
|
397
|
+
|
398
|
+
@usage_lib.entrypoint
# pylint: disable=redefined-builtin
def cancel(name: Optional[str] = None,
           job_ids: Optional[List[int]] = None,
           all: bool = False,
           all_users: bool = False) -> None:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Cancels managed jobs.

    Please refer to sky.cli.job_cancel for documentation.

    Exactly one selector must be given: `job_ids`, `name`, or `all` /
    `all_users`.

    Args:
        name: Cancel the managed job with this name.
        job_ids: Cancel the managed jobs with these IDs.
        all: Cancel via `cancel_jobs_by_id(None)` (all jobs in that call's
          default scope).
        all_users: Cancel via `cancel_jobs_by_id(None, all_users=True)`.

    Raises:
        ValueError: not exactly one selector was specified.
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
        RuntimeError: failed to cancel the job, or the name matched multiple
          jobs.
    """
    job_ids = [] if job_ids is None else job_ids

    # Validate the arguments before touching the controller, so bad input is
    # always reported as a ValueError regardless of the controller's state.
    num_selectors = sum(
        [bool(job_ids), name is not None,
         bool(all or all_users)])
    if num_selectors != 1:
        job_id_str = ','.join(map(str, job_ids))
        arguments = []
        arguments += [f'job_ids={job_id_str}'] if job_ids else []
        arguments += [f'name={name}'] if name is not None else []
        arguments += ['all'] if all else []
        arguments += ['all_users'] if all_users else []
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Can only specify one of JOB_IDS, name, or all/'
                             f'all_users. Provided {" ".join(arguments)!r}.')

    handle = backend_utils.is_controller_accessible(
        controller=controller_utils.Controllers.JOBS_CONTROLLER,
        stopped_message='All managed jobs should have finished.')

    backend = backend_utils.get_backend_from_handle(handle)
    assert isinstance(backend, backends.CloudVmRayBackend)
    if all_users:
        code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(
            None, all_users=True)
    elif all:
        code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(None)
    elif job_ids:
        code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(job_ids)
    else:
        assert name is not None, (job_ids, name, all)
        code = managed_job_utils.ManagedJobCodeGen.cancel_job_by_name(name)
    # The stderr is redirected to stdout
    returncode, stdout, _ = backend.run_on_head(handle,
                                                code,
                                                require_outputs=True,
                                                stream_logs=False)
    try:
        subprocess_utils.handle_returncode(returncode, code,
                                           'Failed to cancel managed job',
                                           stdout)
    except exceptions.CommandError as e:
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError(e.error_msg) from e

    sky_logging.print(stdout)
    if 'Multiple jobs found with name' in stdout:
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError(
                'Please specify the job ID instead of the job name.')
|
459
|
+
|
460
|
+
|
461
|
+
@usage_lib.entrypoint
def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
              controller: bool, refresh: bool) -> None:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Tail logs of managed jobs.

    Please refer to sky.cli.job_logs for documentation.

    Raises:
        ValueError: invalid arguments.
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
    """
    # Local import: only needed to build a safe, copy-pasteable hint command.
    import shlex

    # TODO(zhwu): Automatically restart the jobs controller
    if name is not None and job_id is not None:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Cannot specify both name and job_id.')

    jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
    # Argument suffix for the `sky jobs logs` command suggested to the user
    # when the controller is stopped. The job name is shell-quoted so names
    # containing spaces or shell metacharacters still yield a runnable command
    # (plain names are left unchanged by shlex.quote).
    if job_id is not None:
        job_name_or_id_str = str(job_id)
    elif name is not None:
        job_name_or_id_str = f'-n {shlex.quote(name)}'
    else:
        job_name_or_id_str = ''
    handle = _maybe_restart_controller(
        refresh,
        stopped_message=(
            f'{jobs_controller_type.value.name.capitalize()} is stopped. To '
            f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs '
            f'-r {job_name_or_id_str}{colorama.Style.RESET_ALL}'),
        spinner_message='Retrieving job logs')

    backend = backend_utils.get_backend_from_handle(handle)
    assert isinstance(backend, backends.CloudVmRayBackend), backend

    backend.tail_managed_job_logs(handle,
                                  job_id=job_id,
                                  job_name=name,
                                  follow=follow,
                                  controller=controller)
|
502
|
+
|
503
|
+
|
504
|
+
def start_dashboard_forwarding(refresh: bool = False) -> Tuple[int, int]:
    """Expose the managed-jobs dashboard on a local port via SSH forwarding.

    The jobs controller must be UP. `refresh` is forwarded to
    `_maybe_restart_controller`, which is given a hint message to show if the
    controller is not up.

    Args:
        refresh: passed through to `_maybe_restart_controller`.

    Returns:
        A `(local_port, pid)` tuple: the local port the dashboard is served on
        and the PID of the background shell running the SSH forward. Pass the
        PID to `stop_dashboard_forwarding` to tear it down.
    """
    # TODO(SKY-1212): ideally, the controller/dashboard server should expose the
    # API perhaps via REST. Then here we would (1) not have to use SSH to try to
    # see if the controller is UP first, which is slow; (2) not have to run SSH
    # port forwarding first (we'd just launch a local dashboard which would make
    # REST API calls to the controller dashboard server).
    logger.info('Starting dashboard')
    not_up_hint = ('Dashboard is not available if jobs controller is not up. '
                   'Run a managed job first or run: sky jobs queue --refresh')
    controller_handle = _maybe_restart_controller(
        refresh=refresh,
        stopped_message=not_up_hint,
        spinner_message='Checking jobs controller')

    # Forward a locally-free port (chosen by common_utils.find_free_port) to
    # the dashboard port on the controller's first command runner.
    dashboard_port = skylet_constants.SPOT_DASHBOARD_REMOTE_PORT
    local_port = common_utils.find_free_port(dashboard_port)
    head_runner = controller_handle.get_command_runners()[0]
    ssh_argv = head_runner.port_forward_command(
        port_forward=[(local_port, dashboard_port)], connect_timeout=1)
    # The command is run with shell=True, so the shell expands `~` in the
    # redirect target; stdout and stderr both go to the per-user log file.
    log_file = ('~/sky_logs/api_server/'
                f'dashboard-{common_utils.get_user_hash()}.log')
    forward_cmd = f'{" ".join(ssh_argv)} > {log_file} 2>&1'
    logger.info(f'Forwarding port: {colorama.Style.DIM}{forward_cmd}'
                f'{colorama.Style.RESET_ALL}')

    # New session => the forward lives in its own process group, which lets
    # stop_dashboard_forwarding kill the whole group.
    forwarder = subprocess.Popen(forward_cmd,
                                 shell=True,
                                 start_new_session=True)
    time.sleep(3)  # Give the SSH command time to initialize.
    logger.info(f'{colorama.Fore.GREEN}Dashboard is now available at: '
                f'http://127.0.0.1:{local_port}{colorama.Style.RESET_ALL}')

    return local_port, forwarder.pid
|
541
|
+
|
542
|
+
|
543
|
+
def stop_dashboard_forwarding(pid: int) -> None:
    """Terminate the SSH port-forwarding process group for the dashboard.

    Sends SIGTERM to the whole process group containing `pid`. A process
    that no longer exists is treated as already stopped.

    Args:
        pid: PID returned by `start_dashboard_forwarding`.
    """
    try:
        forward_group = os.getpgid(pid)
        os.killpg(forward_group, signal.SIGTERM)
    except ProcessLookupError:
        # The forwarding process is already gone (per the original note, this
        # happens if the jobs controller was auto-stopped); nothing to do.
        pass
    logger.info('Forwarding port closed. Exiting.')
|
551
|
+
|
552
|
+
|
553
|
+
@usage_lib.entrypoint
def download_logs(
        name: Optional[str],
        job_id: Optional[int],
        refresh: bool,
        controller: bool,
        local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> Dict[str, str]:
    """Sync down logs of managed jobs.

    Please refer to sky.cli.job_logs for documentation.

    Returns:
        A dictionary mapping job ID to the local path.

    Raises:
        ValueError: invalid arguments.
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
    """
    # Local import: only needed to build a safe, copy-pasteable hint command.
    import shlex

    if name is not None and job_id is not None:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Cannot specify both name and job_id.')

    jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
    # Argument suffix for the `sky jobs logs --sync-down` command suggested to
    # the user when the controller is stopped. The job name is shell-quoted so
    # names containing spaces or shell metacharacters still yield a runnable
    # command (plain names are left unchanged by shlex.quote).
    if job_id is not None:
        job_name_or_id_str = str(job_id)
    elif name is not None:
        job_name_or_id_str = f'-n {shlex.quote(name)}'
    else:
        job_name_or_id_str = ''
    handle = _maybe_restart_controller(
        refresh,
        stopped_message=(
            f'{jobs_controller_type.value.name.capitalize()} is stopped. To '
            f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs '
            f'-r --sync-down {job_name_or_id_str}{colorama.Style.RESET_ALL}'),
        spinner_message='Retrieving job logs')

    backend = backend_utils.get_backend_from_handle(handle)
    assert isinstance(backend, backends.CloudVmRayBackend), backend

    return backend.sync_down_managed_job_logs(handle,
                                              job_id=job_id,
                                              job_name=name,
                                              controller=controller,
                                              local_dir=local_dir)
|
@@ -0,0 +1,69 @@
|
|
1
|
+
"""Persistent dashboard sessions.
|
2
|
+
|
3
|
+
Note: before #4717, this was useful because we needed to tunnel to multiple
|
4
|
+
controllers - one per user. Now, there is only one controller for the whole API
|
5
|
+
server, so this is not very useful. TODO(cooperc): Remove or fix this.
|
6
|
+
"""
|
7
|
+
import pathlib
|
8
|
+
from typing import Tuple
|
9
|
+
|
10
|
+
import filelock
|
11
|
+
|
12
|
+
from sky.utils import db_utils
|
13
|
+
|
14
|
+
|
15
|
+
def create_dashboard_table(cursor, conn):
    """Create the `dashboard_sessions` table if it is missing, then commit.

    Used as the table-creation callback for `db_utils.SQLiteConn`. Each row
    maps a user hash (primary key) to a forwarded port and a process id.

    Args:
        cursor: DB-API cursor used to run the DDL statement.
        conn: connection whose transaction is committed afterwards.
    """
    ddl = """\
        CREATE TABLE IF NOT EXISTS dashboard_sessions (
        user_hash TEXT PRIMARY KEY,
        port INTEGER,
        pid INTEGER)"""
    cursor.execute(ddl)
    conn.commit()
|
22
|
+
|
23
|
+
|
24
|
+
def _get_db_path() -> str:
|
25
|
+
path = pathlib.Path('~/.sky/dashboard/sessions.db')
|
26
|
+
path = path.expanduser().absolute()
|
27
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
28
|
+
return str(path)
|
29
|
+
|
30
|
+
|
31
|
+
# Per-user lock file template. It contains a literal `~`, so callers must
# expanduser() the formatted path before touching the filesystem.
LOCK_FILE_PATH = '~/.sky/dashboard/sessions-{user_hash}.lock'

DB_PATH = _get_db_path()
# Constructed at import time only for its side effect; create_dashboard_table
# is passed as the table-creation callback. The returned object is discarded.
# NOTE(review): confirm SQLiteConn does not need to be kept alive.
db_utils.SQLiteConn(DB_PATH, create_dashboard_table)
|
34
|
+
|
35
|
+
|
36
|
+
def get_dashboard_session(user_hash: str) -> Tuple[int, int]:
    """Look up the recorded dashboard session for `user_hash`.

    Returns:
        The stored `(port, pid)` row, or `(0, 0)` when the user has no
        recorded session.
    """
    query = 'SELECT port, pid FROM dashboard_sessions WHERE user_hash=?'
    with db_utils.safe_cursor(DB_PATH) as cursor:
        cursor.execute(query, (user_hash,))
        row = cursor.fetchone()
    return (0, 0) if row is None else row
|
46
|
+
|
47
|
+
|
48
|
+
def add_dashboard_session(user_hash: str, port: int, pid: int) -> None:
    """Record the dashboard session for `user_hash`.

    Uses INSERT OR REPLACE, so an existing row for the same user is
    overwritten.
    """
    row = (user_hash, port, pid)
    with db_utils.safe_cursor(DB_PATH) as cursor:
        cursor.execute(
            'INSERT OR REPLACE INTO dashboard_sessions (user_hash, port, pid) '
            'VALUES (?, ?, ?)', row)
|
54
|
+
|
55
|
+
|
56
|
+
def remove_dashboard_session(user_hash: str) -> None:
    """Remove the dashboard session and its lock file for the user.

    Deletes the user's row from `dashboard_sessions`, then deletes the
    per-user lock file if it exists.

    Args:
        user_hash: the user whose session should be removed.
    """
    with db_utils.safe_cursor(DB_PATH) as cursor:
        cursor.execute('DELETE FROM dashboard_sessions WHERE user_hash=?',
                       (user_hash,))
    # LOCK_FILE_PATH begins with a literal `~`. Without expanduser() this would
    # resolve relative to the current working directory (`./~/.sky/...`) and
    # never delete the real lock file, which get_dashboard_lock_for_user
    # creates under the expanded home directory.
    lock_path = pathlib.Path(LOCK_FILE_PATH.format(user_hash=user_hash))
    lock_path = lock_path.expanduser().absolute()
    lock_path.unlink(missing_ok=True)
|
63
|
+
|
64
|
+
|
65
|
+
def get_dashboard_lock_for_user(user_hash: str) -> filelock.FileLock:
    """Build the per-user dashboard file lock.

    Expands `~` in LOCK_FILE_PATH and creates the lock file's parent directory
    if needed. The lock object is only constructed here; callers acquire it.
    """
    lock_file = pathlib.Path(
        LOCK_FILE_PATH.format(user_hash=user_hash)).expanduser().absolute()
    lock_file.parent.mkdir(parents=True, exist_ok=True)
    return filelock.FileLock(lock_file)
|