skypilot-nightly 1.0.0.dev20250215__py3-none-any.whl → 1.0.0.dev20250217__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +48 -22
- sky/adaptors/aws.py +2 -1
- sky/adaptors/azure.py +4 -4
- sky/adaptors/cloudflare.py +4 -4
- sky/adaptors/kubernetes.py +8 -8
- sky/authentication.py +42 -45
- sky/backends/backend.py +2 -2
- sky/backends/backend_utils.py +108 -221
- sky/backends/cloud_vm_ray_backend.py +283 -282
- sky/benchmark/benchmark_utils.py +6 -2
- sky/check.py +40 -28
- sky/cli.py +1213 -1116
- sky/client/__init__.py +1 -0
- sky/client/cli.py +5644 -0
- sky/client/common.py +345 -0
- sky/client/sdk.py +1757 -0
- sky/cloud_stores.py +12 -6
- sky/clouds/__init__.py +0 -2
- sky/clouds/aws.py +20 -13
- sky/clouds/azure.py +5 -3
- sky/clouds/cloud.py +1 -1
- sky/clouds/cudo.py +2 -1
- sky/clouds/do.py +2 -1
- sky/clouds/fluidstack.py +3 -2
- sky/clouds/gcp.py +10 -8
- sky/clouds/ibm.py +8 -7
- sky/clouds/kubernetes.py +7 -6
- sky/clouds/lambda_cloud.py +8 -7
- sky/clouds/oci.py +4 -3
- sky/clouds/paperspace.py +2 -1
- sky/clouds/runpod.py +2 -1
- sky/clouds/scp.py +8 -7
- sky/clouds/service_catalog/__init__.py +3 -3
- sky/clouds/service_catalog/aws_catalog.py +7 -1
- sky/clouds/service_catalog/common.py +4 -2
- sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +2 -2
- sky/clouds/utils/oci_utils.py +1 -1
- sky/clouds/vast.py +2 -1
- sky/clouds/vsphere.py +2 -1
- sky/core.py +263 -99
- sky/dag.py +4 -0
- sky/data/mounting_utils.py +2 -1
- sky/data/storage.py +97 -35
- sky/data/storage_utils.py +69 -9
- sky/exceptions.py +138 -5
- sky/execution.py +47 -50
- sky/global_user_state.py +105 -22
- sky/jobs/__init__.py +12 -14
- sky/jobs/client/__init__.py +0 -0
- sky/jobs/client/sdk.py +296 -0
- sky/jobs/constants.py +30 -1
- sky/jobs/controller.py +12 -6
- sky/jobs/dashboard/dashboard.py +2 -6
- sky/jobs/recovery_strategy.py +22 -29
- sky/jobs/server/__init__.py +1 -0
- sky/jobs/{core.py → server/core.py} +101 -34
- sky/jobs/server/dashboard_utils.py +64 -0
- sky/jobs/server/server.py +182 -0
- sky/jobs/utils.py +32 -23
- sky/models.py +27 -0
- sky/optimizer.py +9 -11
- sky/provision/__init__.py +6 -3
- sky/provision/aws/config.py +2 -2
- sky/provision/aws/instance.py +1 -1
- sky/provision/azure/instance.py +1 -1
- sky/provision/cudo/instance.py +1 -1
- sky/provision/do/instance.py +1 -1
- sky/provision/do/utils.py +0 -5
- sky/provision/fluidstack/fluidstack_utils.py +4 -3
- sky/provision/fluidstack/instance.py +4 -2
- sky/provision/gcp/instance.py +1 -1
- sky/provision/instance_setup.py +2 -2
- sky/provision/kubernetes/constants.py +8 -0
- sky/provision/kubernetes/instance.py +1 -1
- sky/provision/kubernetes/utils.py +67 -76
- sky/provision/lambda_cloud/instance.py +3 -15
- sky/provision/logging.py +1 -1
- sky/provision/oci/instance.py +7 -4
- sky/provision/paperspace/instance.py +1 -1
- sky/provision/provisioner.py +3 -2
- sky/provision/runpod/instance.py +1 -1
- sky/provision/vast/instance.py +1 -1
- sky/provision/vast/utils.py +2 -1
- sky/provision/vsphere/instance.py +2 -11
- sky/resources.py +55 -40
- sky/serve/__init__.py +6 -10
- sky/serve/client/__init__.py +0 -0
- sky/serve/client/sdk.py +366 -0
- sky/serve/constants.py +3 -0
- sky/serve/replica_managers.py +10 -10
- sky/serve/serve_utils.py +56 -36
- sky/serve/server/__init__.py +0 -0
- sky/serve/{core.py → server/core.py} +37 -17
- sky/serve/server/server.py +117 -0
- sky/serve/service.py +8 -1
- sky/server/__init__.py +1 -0
- sky/server/common.py +441 -0
- sky/server/constants.py +21 -0
- sky/server/html/log.html +174 -0
- sky/server/requests/__init__.py +0 -0
- sky/server/requests/executor.py +462 -0
- sky/server/requests/payloads.py +481 -0
- sky/server/requests/queues/__init__.py +0 -0
- sky/server/requests/queues/mp_queue.py +76 -0
- sky/server/requests/requests.py +567 -0
- sky/server/requests/serializers/__init__.py +0 -0
- sky/server/requests/serializers/decoders.py +192 -0
- sky/server/requests/serializers/encoders.py +166 -0
- sky/server/server.py +1095 -0
- sky/server/stream_utils.py +144 -0
- sky/setup_files/MANIFEST.in +1 -0
- sky/setup_files/dependencies.py +12 -4
- sky/setup_files/setup.py +1 -1
- sky/sky_logging.py +9 -13
- sky/skylet/autostop_lib.py +2 -2
- sky/skylet/constants.py +46 -12
- sky/skylet/events.py +5 -6
- sky/skylet/job_lib.py +78 -66
- sky/skylet/log_lib.py +17 -11
- sky/skypilot_config.py +79 -94
- sky/task.py +119 -73
- sky/templates/aws-ray.yml.j2 +4 -4
- sky/templates/azure-ray.yml.j2 +3 -2
- sky/templates/cudo-ray.yml.j2 +3 -2
- sky/templates/fluidstack-ray.yml.j2 +3 -2
- sky/templates/gcp-ray.yml.j2 +3 -2
- sky/templates/ibm-ray.yml.j2 +3 -2
- sky/templates/jobs-controller.yaml.j2 +1 -12
- sky/templates/kubernetes-ray.yml.j2 +3 -2
- sky/templates/lambda-ray.yml.j2 +3 -2
- sky/templates/oci-ray.yml.j2 +3 -2
- sky/templates/paperspace-ray.yml.j2 +3 -2
- sky/templates/runpod-ray.yml.j2 +3 -2
- sky/templates/scp-ray.yml.j2 +3 -2
- sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
- sky/templates/vsphere-ray.yml.j2 +4 -2
- sky/templates/websocket_proxy.py +64 -0
- sky/usage/constants.py +8 -0
- sky/usage/usage_lib.py +45 -11
- sky/utils/accelerator_registry.py +33 -53
- sky/utils/admin_policy_utils.py +2 -1
- sky/utils/annotations.py +51 -0
- sky/utils/cli_utils/status_utils.py +33 -3
- sky/utils/cluster_utils.py +356 -0
- sky/utils/command_runner.py +69 -14
- sky/utils/common.py +74 -0
- sky/utils/common_utils.py +133 -93
- sky/utils/config_utils.py +204 -0
- sky/utils/control_master_utils.py +2 -3
- sky/utils/controller_utils.py +133 -147
- sky/utils/dag_utils.py +72 -24
- sky/utils/kubernetes/deploy_remote_cluster.sh +2 -2
- sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
- sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
- sky/utils/log_utils.py +83 -23
- sky/utils/message_utils.py +81 -0
- sky/utils/registry.py +127 -0
- sky/utils/resources_utils.py +2 -2
- sky/utils/rich_utils.py +213 -34
- sky/utils/schemas.py +19 -2
- sky/{status_lib.py → utils/status_lib.py} +12 -7
- sky/utils/subprocess_utils.py +51 -35
- sky/utils/timeline.py +7 -2
- sky/utils/ux_utils.py +95 -25
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/METADATA +8 -3
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/RECORD +170 -132
- sky/clouds/cloud_registry.py +0 -76
- sky/utils/cluster_yaml_utils.py +0 -24
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,21 @@
|
|
1
1
|
"""SDK functions for managed jobs."""
|
2
2
|
import os
|
3
|
+
import signal
|
4
|
+
import subprocess
|
3
5
|
import tempfile
|
6
|
+
import time
|
4
7
|
import typing
|
5
8
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
6
9
|
import uuid
|
7
10
|
|
8
11
|
import colorama
|
9
12
|
|
10
|
-
import sky
|
11
13
|
from sky import backends
|
14
|
+
from sky import core
|
12
15
|
from sky import exceptions
|
16
|
+
from sky import execution
|
13
17
|
from sky import provision as provision_lib
|
14
18
|
from sky import sky_logging
|
15
|
-
from sky import status_lib
|
16
19
|
from sky import task as task_lib
|
17
20
|
from sky.backends import backend_utils
|
18
21
|
from sky.clouds.service_catalog import common as service_catalog_common
|
@@ -26,13 +29,17 @@ from sky.utils import common_utils
|
|
26
29
|
from sky.utils import controller_utils
|
27
30
|
from sky.utils import dag_utils
|
28
31
|
from sky.utils import rich_utils
|
32
|
+
from sky.utils import status_lib
|
29
33
|
from sky.utils import subprocess_utils
|
30
34
|
from sky.utils import timeline
|
31
35
|
from sky.utils import ux_utils
|
32
36
|
|
33
37
|
if typing.TYPE_CHECKING:
|
38
|
+
import sky
|
34
39
|
from sky.backends import cloud_vm_ray_backend
|
35
40
|
|
41
|
+
logger = sky_logging.init_logger(__name__)
|
42
|
+
|
36
43
|
|
37
44
|
@timeline.event
|
38
45
|
@usage_lib.entrypoint
|
@@ -40,10 +47,9 @@ def launch(
|
|
40
47
|
task: Union['sky.Task', 'sky.Dag'],
|
41
48
|
name: Optional[str] = None,
|
42
49
|
stream_logs: bool = True,
|
43
|
-
detach_run: bool = False,
|
44
50
|
) -> Tuple[Optional[int], Optional[backends.ResourceHandle]]:
|
45
51
|
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
|
46
|
-
"""
|
52
|
+
"""Launches a managed job.
|
47
53
|
|
48
54
|
Please refer to sky.cli.job_launch for documentation.
|
49
55
|
|
@@ -51,7 +57,6 @@ def launch(
|
|
51
57
|
task: sky.Task, or sky.Dag (experimental; 1-task only) to launch as a
|
52
58
|
managed job.
|
53
59
|
name: Name of the managed job.
|
54
|
-
detach_run: Whether to detach the run.
|
55
60
|
|
56
61
|
Raises:
|
57
62
|
ValueError: cluster does not exist. Or, the entrypoint is not a valid
|
@@ -77,7 +82,7 @@ def launch(
|
|
77
82
|
with ux_utils.print_exception_no_traceback():
|
78
83
|
raise ValueError('Only single-task or chain DAG is '
|
79
84
|
f'allowed for job_launch. Dag: {dag}')
|
80
|
-
|
85
|
+
dag.validate()
|
81
86
|
dag_utils.maybe_infer_and_fill_dag_and_task_names(dag)
|
82
87
|
|
83
88
|
task_names = set()
|
@@ -120,6 +125,7 @@ def launch(
|
|
120
125
|
'remote_user_config_path': remote_user_config_path,
|
121
126
|
'modified_catalogs':
|
122
127
|
service_catalog_common.get_modified_catalog_file_mounts(),
|
128
|
+
'dashboard_setup_cmd': managed_job_constants.DASHBOARD_SETUP_CMD,
|
123
129
|
**controller_utils.shared_controller_vars_to_fill(
|
124
130
|
controller_utils.Controllers.JOBS_CONTROLLER,
|
125
131
|
remote_user_config_path=remote_user_config_path,
|
@@ -143,15 +149,14 @@ def launch(
|
|
143
149
|
f'{colorama.Fore.YELLOW}'
|
144
150
|
f'Launching managed job {dag.name!r} from jobs controller...'
|
145
151
|
f'{colorama.Style.RESET_ALL}')
|
146
|
-
return
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
_disable_controller_check=True)
|
152
|
+
return execution.launch(task=controller_task,
|
153
|
+
cluster_name=controller_name,
|
154
|
+
stream_logs=stream_logs,
|
155
|
+
idle_minutes_to_autostop=skylet_constants.
|
156
|
+
CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP,
|
157
|
+
retry_until_up=True,
|
158
|
+
fast=True,
|
159
|
+
_disable_controller_check=True)
|
155
160
|
|
156
161
|
|
157
162
|
def queue_from_kubernetes_pod(
|
@@ -256,7 +261,22 @@ def _maybe_restart_controller(
|
|
256
261
|
rich_utils.force_update_status(
|
257
262
|
ux_utils.spinner_message(f'{spinner_message} - restarting '
|
258
263
|
'controller'))
|
259
|
-
handle =
|
264
|
+
handle = core.start(cluster_name=jobs_controller_type.value.cluster_name)
|
265
|
+
# Make sure the dashboard is running when the controller is restarted.
|
266
|
+
# We should not directly use execution.launch() and have the dashboard cmd
|
267
|
+
# in the task setup because since we are using detached_setup, it will
|
268
|
+
# become a job on controller which messes up the job IDs (we assume the
|
269
|
+
# job ID in controller's job queue is consistent with managed job IDs).
|
270
|
+
with rich_utils.safe_status(
|
271
|
+
ux_utils.spinner_message('Starting dashboard...')):
|
272
|
+
runner = handle.get_command_runners()[0]
|
273
|
+
user_hash = common_utils.get_user_hash()
|
274
|
+
runner.run(
|
275
|
+
f'export '
|
276
|
+
f'{skylet_constants.USER_ID_ENV_VAR}={user_hash!r}; '
|
277
|
+
f'{managed_job_constants.DASHBOARD_SETUP_CMD}',
|
278
|
+
stream_logs=True,
|
279
|
+
)
|
260
280
|
controller_status = status_lib.ClusterStatus.UP
|
261
281
|
rich_utils.force_update_status(ux_utils.spinner_message(spinner_message))
|
262
282
|
|
@@ -267,7 +287,7 @@ def _maybe_restart_controller(
|
|
267
287
|
@usage_lib.entrypoint
|
268
288
|
def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
|
269
289
|
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
|
270
|
-
"""
|
290
|
+
"""Gets statuses of managed jobs.
|
271
291
|
|
272
292
|
Please refer to sky.cli.job_queue for documentation.
|
273
293
|
|
@@ -307,14 +327,10 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
|
|
307
327
|
stream_logs=False,
|
308
328
|
separate_stderr=True)
|
309
329
|
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
job_table_payload + stderr,
|
315
|
-
stream_logs=False)
|
316
|
-
except exceptions.CommandError as e:
|
317
|
-
raise RuntimeError(str(e)) from e
|
330
|
+
if returncode != 0:
|
331
|
+
logger.error(job_table_payload + stderr)
|
332
|
+
raise RuntimeError('Failed to fetch managed jobs with returncode: '
|
333
|
+
f'{returncode}')
|
318
334
|
|
319
335
|
jobs = managed_job_utils.load_managed_job_queue(job_table_payload)
|
320
336
|
if skip_finished:
|
@@ -334,7 +350,7 @@ def cancel(name: Optional[str] = None,
|
|
334
350
|
job_ids: Optional[List[int]] = None,
|
335
351
|
all: bool = False) -> None:
|
336
352
|
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
|
337
|
-
"""
|
353
|
+
"""Cancels managed jobs.
|
338
354
|
|
339
355
|
Please refer to sky.cli.job_cancel for documentation.
|
340
356
|
|
@@ -428,22 +444,73 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
|
|
428
444
|
controller=controller)
|
429
445
|
|
430
446
|
|
447
|
+
def start_dashboard_forwarding(refresh: bool = False) -> Tuple[int, int]:
|
448
|
+
"""Opens a dashboard for managed jobs (needs controller to be UP)."""
|
449
|
+
# TODO(SKY-1212): ideally, the controller/dashboard server should expose the
|
450
|
+
# API perhaps via REST. Then here we would (1) not have to use SSH to try to
|
451
|
+
# see if the controller is UP first, which is slow; (2) not have to run SSH
|
452
|
+
# port forwarding first (we'd just launch a local dashboard which would make
|
453
|
+
# REST API calls to the controller dashboard server).
|
454
|
+
logger.info('Starting dashboard')
|
455
|
+
hint = ('Dashboard is not available if jobs controller is not up. Run '
|
456
|
+
'a managed job first or run: sky jobs queue --refresh')
|
457
|
+
handle = _maybe_restart_controller(
|
458
|
+
refresh=refresh,
|
459
|
+
stopped_message=hint,
|
460
|
+
spinner_message='Checking jobs controller')
|
461
|
+
|
462
|
+
# SSH forward a free local port to remote's dashboard port.
|
463
|
+
remote_port = skylet_constants.SPOT_DASHBOARD_REMOTE_PORT
|
464
|
+
free_port = common_utils.find_free_port(remote_port)
|
465
|
+
runner = handle.get_command_runners()[0]
|
466
|
+
port_forward_command = ' '.join(
|
467
|
+
runner.port_forward_command(port_forward=[(free_port, remote_port)],
|
468
|
+
connect_timeout=1))
|
469
|
+
port_forward_command = (
|
470
|
+
f'{port_forward_command} '
|
471
|
+
f'> ~/sky_logs/api_server/dashboard-{common_utils.get_user_hash()}.log '
|
472
|
+
'2>&1')
|
473
|
+
logger.info(f'Forwarding port: {colorama.Style.DIM}{port_forward_command}'
|
474
|
+
f'{colorama.Style.RESET_ALL}')
|
475
|
+
|
476
|
+
ssh_process = subprocess.Popen(port_forward_command,
|
477
|
+
shell=True,
|
478
|
+
start_new_session=True)
|
479
|
+
time.sleep(3) # Added delay for ssh_command to initialize.
|
480
|
+
logger.info(f'{colorama.Fore.GREEN}Dashboard is now available at: '
|
481
|
+
f'http://127.0.0.1:{free_port}{colorama.Style.RESET_ALL}')
|
482
|
+
|
483
|
+
return free_port, ssh_process.pid
|
484
|
+
|
485
|
+
|
486
|
+
def stop_dashboard_forwarding(pid: int) -> None:
|
487
|
+
# Exit the ssh command when the context manager is closed.
|
488
|
+
try:
|
489
|
+
os.killpg(os.getpgid(pid), signal.SIGTERM)
|
490
|
+
except ProcessLookupError:
|
491
|
+
# This happens if jobs controller is auto-stopped.
|
492
|
+
pass
|
493
|
+
logger.info('Forwarding port closed. Exiting.')
|
494
|
+
|
495
|
+
|
431
496
|
@usage_lib.entrypoint
|
432
|
-
def
|
497
|
+
def download_logs(
|
433
498
|
name: Optional[str],
|
434
499
|
job_id: Optional[int],
|
435
500
|
refresh: bool,
|
436
501
|
controller: bool,
|
437
|
-
local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) ->
|
502
|
+
local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> Dict[str, str]:
|
438
503
|
"""Sync down logs of managed jobs.
|
439
504
|
|
440
505
|
Please refer to sky.cli.job_logs for documentation.
|
441
506
|
|
507
|
+
Returns:
|
508
|
+
A dictionary mapping job ID to the local path.
|
509
|
+
|
442
510
|
Raises:
|
443
511
|
ValueError: invalid arguments.
|
444
512
|
sky.exceptions.ClusterNotUpError: the jobs controller is not up.
|
445
513
|
"""
|
446
|
-
# TODO(zhwu): Automatically restart the jobs controller
|
447
514
|
if name is not None and job_id is not None:
|
448
515
|
with ux_utils.print_exception_no_traceback():
|
449
516
|
raise ValueError('Cannot specify both name and job_id.')
|
@@ -467,8 +534,8 @@ def sync_down_logs(
|
|
467
534
|
backend = backend_utils.get_backend_from_handle(handle)
|
468
535
|
assert isinstance(backend, backends.CloudVmRayBackend), backend
|
469
536
|
|
470
|
-
backend.sync_down_managed_job_logs(handle,
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
537
|
+
return backend.sync_down_managed_job_logs(handle,
|
538
|
+
job_id=job_id,
|
539
|
+
job_name=name,
|
540
|
+
controller=controller,
|
541
|
+
local_dir=local_dir)
|
@@ -0,0 +1,64 @@
|
|
1
|
+
"""Persistent dashboard sessions."""
|
2
|
+
import pathlib
|
3
|
+
from typing import Tuple
|
4
|
+
|
5
|
+
import filelock
|
6
|
+
|
7
|
+
from sky.utils import db_utils
|
8
|
+
|
9
|
+
|
10
|
+
def create_dashboard_table(cursor, conn):
|
11
|
+
cursor.execute("""\
|
12
|
+
CREATE TABLE IF NOT EXISTS dashboard_sessions (
|
13
|
+
user_hash TEXT PRIMARY KEY,
|
14
|
+
port INTEGER,
|
15
|
+
pid INTEGER)""")
|
16
|
+
conn.commit()
|
17
|
+
|
18
|
+
|
19
|
+
def _get_db_path() -> str:
|
20
|
+
path = pathlib.Path('~/.sky/dashboard/sessions.db')
|
21
|
+
path = path.expanduser().absolute()
|
22
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
23
|
+
return str(path)
|
24
|
+
|
25
|
+
|
26
|
+
DB_PATH = _get_db_path()
|
27
|
+
db_utils.SQLiteConn(DB_PATH, create_dashboard_table)
|
28
|
+
LOCK_FILE_PATH = '~/.sky/dashboard/sessions-{user_hash}.lock'
|
29
|
+
|
30
|
+
|
31
|
+
def get_dashboard_session(user_hash: str) -> Tuple[int, int]:
|
32
|
+
"""Get the port and pid of the dashboard session for the user."""
|
33
|
+
with db_utils.safe_cursor(DB_PATH) as cursor:
|
34
|
+
cursor.execute(
|
35
|
+
'SELECT port, pid FROM dashboard_sessions WHERE user_hash=?',
|
36
|
+
(user_hash,))
|
37
|
+
result = cursor.fetchone()
|
38
|
+
if result is None:
|
39
|
+
return 0, 0
|
40
|
+
return result
|
41
|
+
|
42
|
+
|
43
|
+
def add_dashboard_session(user_hash: str, port: int, pid: int) -> None:
|
44
|
+
"""Add a dashboard session for the user."""
|
45
|
+
with db_utils.safe_cursor(DB_PATH) as cursor:
|
46
|
+
cursor.execute(
|
47
|
+
'INSERT OR REPLACE INTO dashboard_sessions (user_hash, port, pid) '
|
48
|
+
'VALUES (?, ?, ?)', (user_hash, port, pid))
|
49
|
+
|
50
|
+
|
51
|
+
def remove_dashboard_session(user_hash: str) -> None:
|
52
|
+
"""Remove the dashboard session for the user."""
|
53
|
+
with db_utils.safe_cursor(DB_PATH) as cursor:
|
54
|
+
cursor.execute('DELETE FROM dashboard_sessions WHERE user_hash=?',
|
55
|
+
(user_hash,))
|
56
|
+
lock_path = pathlib.Path(LOCK_FILE_PATH.format(user_hash=user_hash))
|
57
|
+
lock_path.unlink(missing_ok=True)
|
58
|
+
|
59
|
+
|
60
|
+
def get_dashboard_lock_for_user(user_hash: str) -> filelock.FileLock:
|
61
|
+
path = pathlib.Path(LOCK_FILE_PATH.format(user_hash=user_hash))
|
62
|
+
path = path.expanduser().absolute()
|
63
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
64
|
+
return filelock.FileLock(path)
|
@@ -0,0 +1,182 @@
|
|
1
|
+
"""REST API for managed jobs."""
|
2
|
+
import os
|
3
|
+
|
4
|
+
import fastapi
|
5
|
+
import httpx
|
6
|
+
|
7
|
+
from sky import sky_logging
|
8
|
+
from sky.jobs.server import core
|
9
|
+
from sky.jobs.server import dashboard_utils
|
10
|
+
from sky.server import common as server_common
|
11
|
+
from sky.server import stream_utils
|
12
|
+
from sky.server.requests import executor
|
13
|
+
from sky.server.requests import payloads
|
14
|
+
from sky.server.requests import requests as api_requests
|
15
|
+
from sky.skylet import constants
|
16
|
+
from sky.utils import common
|
17
|
+
from sky.utils import common_utils
|
18
|
+
|
19
|
+
logger = sky_logging.init_logger(__name__)
|
20
|
+
|
21
|
+
router = fastapi.APIRouter()
|
22
|
+
|
23
|
+
|
24
|
+
def _get_controller_name(request_body: payloads.RequestBody) -> str:
|
25
|
+
user_hash = request_body.user_hash
|
26
|
+
return common.get_controller_name(common.ControllerType.JOBS, user_hash)
|
27
|
+
|
28
|
+
|
29
|
+
@router.post('/launch')
|
30
|
+
async def launch(request: fastapi.Request,
|
31
|
+
jobs_launch_body: payloads.JobsLaunchBody) -> None:
|
32
|
+
executor.schedule_request(
|
33
|
+
request_id=request.state.request_id,
|
34
|
+
request_name='jobs.launch',
|
35
|
+
request_body=jobs_launch_body,
|
36
|
+
func=core.launch,
|
37
|
+
schedule_type=api_requests.ScheduleType.LONG,
|
38
|
+
request_cluster_name=_get_controller_name(jobs_launch_body),
|
39
|
+
)
|
40
|
+
|
41
|
+
|
42
|
+
@router.post('/queue')
|
43
|
+
async def queue(request: fastapi.Request,
|
44
|
+
jobs_queue_body: payloads.JobsQueueBody) -> None:
|
45
|
+
executor.schedule_request(
|
46
|
+
request_id=request.state.request_id,
|
47
|
+
request_name='jobs.queue',
|
48
|
+
request_body=jobs_queue_body,
|
49
|
+
func=core.queue,
|
50
|
+
schedule_type=(api_requests.ScheduleType.LONG if jobs_queue_body.refresh
|
51
|
+
else api_requests.ScheduleType.SHORT),
|
52
|
+
request_cluster_name=_get_controller_name(jobs_queue_body),
|
53
|
+
)
|
54
|
+
|
55
|
+
|
56
|
+
@router.post('/cancel')
|
57
|
+
async def cancel(request: fastapi.Request,
|
58
|
+
jobs_cancel_body: payloads.JobsCancelBody) -> None:
|
59
|
+
executor.schedule_request(
|
60
|
+
request_id=request.state.request_id,
|
61
|
+
request_name='jobs.cancel',
|
62
|
+
request_body=jobs_cancel_body,
|
63
|
+
func=core.cancel,
|
64
|
+
schedule_type=api_requests.ScheduleType.SHORT,
|
65
|
+
request_cluster_name=_get_controller_name(jobs_cancel_body),
|
66
|
+
)
|
67
|
+
|
68
|
+
|
69
|
+
@router.post('/logs')
|
70
|
+
async def logs(
|
71
|
+
request: fastapi.Request, jobs_logs_body: payloads.JobsLogsBody,
|
72
|
+
background_tasks: fastapi.BackgroundTasks
|
73
|
+
) -> fastapi.responses.StreamingResponse:
|
74
|
+
executor.schedule_request(
|
75
|
+
request_id=request.state.request_id,
|
76
|
+
request_name='jobs.logs',
|
77
|
+
request_body=jobs_logs_body,
|
78
|
+
func=core.tail_logs,
|
79
|
+
schedule_type=api_requests.ScheduleType.SHORT
|
80
|
+
if jobs_logs_body.refresh else api_requests.ScheduleType.LONG,
|
81
|
+
request_cluster_name=_get_controller_name(jobs_logs_body),
|
82
|
+
)
|
83
|
+
request_task = api_requests.get_request(request.state.request_id)
|
84
|
+
|
85
|
+
return stream_utils.stream_response(
|
86
|
+
request_id=request_task.request_id,
|
87
|
+
logs_path=request_task.log_path,
|
88
|
+
background_tasks=background_tasks,
|
89
|
+
)
|
90
|
+
|
91
|
+
|
92
|
+
@router.post('/download_logs')
|
93
|
+
async def download_logs(
|
94
|
+
request: fastapi.Request,
|
95
|
+
jobs_download_logs_body: payloads.JobsDownloadLogsBody) -> None:
|
96
|
+
user_hash = jobs_download_logs_body.env_vars[constants.USER_ID_ENV_VAR]
|
97
|
+
logs_dir_on_api_server = server_common.api_server_user_logs_dir_prefix(
|
98
|
+
user_hash)
|
99
|
+
logs_dir_on_api_server.expanduser().mkdir(parents=True, exist_ok=True)
|
100
|
+
# We should reuse the original request body, so that the env vars, such as
|
101
|
+
# user hash, are kept the same.
|
102
|
+
jobs_download_logs_body.local_dir = str(logs_dir_on_api_server)
|
103
|
+
executor.schedule_request(
|
104
|
+
request_id=request.state.request_id,
|
105
|
+
request_name='jobs.download_logs',
|
106
|
+
request_body=jobs_download_logs_body,
|
107
|
+
func=core.download_logs,
|
108
|
+
schedule_type=api_requests.ScheduleType.LONG
|
109
|
+
if jobs_download_logs_body.refresh else api_requests.ScheduleType.SHORT,
|
110
|
+
request_cluster_name=_get_controller_name(jobs_download_logs_body),
|
111
|
+
)
|
112
|
+
|
113
|
+
|
114
|
+
@router.get('/dashboard')
|
115
|
+
async def dashboard(request: fastapi.Request,
|
116
|
+
user_hash: str) -> fastapi.Response:
|
117
|
+
# Find the port for the dashboard of the user
|
118
|
+
os.environ[constants.USER_ID_ENV_VAR] = user_hash
|
119
|
+
server_common.reload_for_new_request(client_entrypoint=None,
|
120
|
+
client_command=None)
|
121
|
+
logger.info(f'Starting dashboard for user hash: {user_hash}')
|
122
|
+
|
123
|
+
with dashboard_utils.get_dashboard_lock_for_user(user_hash):
|
124
|
+
max_retries = 3
|
125
|
+
for attempt in range(max_retries):
|
126
|
+
port, pid = dashboard_utils.get_dashboard_session(user_hash)
|
127
|
+
if port == 0 or attempt > 0:
|
128
|
+
# Let the client know that we are waiting for starting the
|
129
|
+
# dashboard.
|
130
|
+
try:
|
131
|
+
port, pid = core.start_dashboard_forwarding()
|
132
|
+
except Exception as e: # pylint: disable=broad-except
|
133
|
+
# We catch all exceptions to gracefully handle unknown
|
134
|
+
# errors and raise an HTTPException to the client.
|
135
|
+
msg = (
|
136
|
+
'Dashboard failed to start: '
|
137
|
+
f'{common_utils.format_exception(e, use_bracket=True)}')
|
138
|
+
logger.error(msg)
|
139
|
+
raise fastapi.HTTPException(status_code=503, detail=msg)
|
140
|
+
dashboard_utils.add_dashboard_session(user_hash, port, pid)
|
141
|
+
|
142
|
+
# Assuming the dashboard is forwarded to localhost on the API server
|
143
|
+
dashboard_url = f'http://localhost:{port}'
|
144
|
+
try:
|
145
|
+
# Ping the dashboard to check if it's still running
|
146
|
+
async with httpx.AsyncClient() as client:
|
147
|
+
response = await client.request('GET',
|
148
|
+
dashboard_url,
|
149
|
+
timeout=1)
|
150
|
+
break # Connection successful, proceed with the request
|
151
|
+
except Exception as e: # pylint: disable=broad-except
|
152
|
+
# We catch all exceptions to gracefully handle unknown
|
153
|
+
# errors and retry or raise an HTTPException to the client.
|
154
|
+
msg = (
|
155
|
+
f'Dashboard connection attempt {attempt + 1} failed with '
|
156
|
+
f'{common_utils.format_exception(e, use_bracket=True)}')
|
157
|
+
logger.info(msg)
|
158
|
+
if attempt == max_retries - 1:
|
159
|
+
raise fastapi.HTTPException(status_code=503, detail=msg)
|
160
|
+
|
161
|
+
# Create a client session to forward the request
|
162
|
+
try:
|
163
|
+
async with httpx.AsyncClient() as client:
|
164
|
+
# Make the request and get the response
|
165
|
+
response = await client.request(
|
166
|
+
method='GET',
|
167
|
+
url=f'{dashboard_url}',
|
168
|
+
headers=request.headers.raw,
|
169
|
+
)
|
170
|
+
|
171
|
+
# Create a new response with the content already read
|
172
|
+
content = await response.aread()
|
173
|
+
return fastapi.Response(
|
174
|
+
content=content,
|
175
|
+
status_code=response.status_code,
|
176
|
+
headers=dict(response.headers),
|
177
|
+
media_type=response.headers.get('content-type'))
|
178
|
+
except Exception as e:
|
179
|
+
msg = (f'Failed to forward request to dashboard: '
|
180
|
+
f'{common_utils.format_exception(e, use_bracket=True)}')
|
181
|
+
logger.error(msg)
|
182
|
+
raise fastapi.HTTPException(status_code=502, detail=msg)
|
sky/jobs/utils.py
CHANGED
@@ -20,7 +20,6 @@ import filelock
|
|
20
20
|
import psutil
|
21
21
|
from typing_extensions import Literal
|
22
22
|
|
23
|
-
import sky
|
24
23
|
from sky import backends
|
25
24
|
from sky import exceptions
|
26
25
|
from sky import global_user_state
|
@@ -35,21 +34,17 @@ from sky.skylet import log_lib
|
|
35
34
|
from sky.usage import usage_lib
|
36
35
|
from sky.utils import common_utils
|
37
36
|
from sky.utils import log_utils
|
37
|
+
from sky.utils import message_utils
|
38
38
|
from sky.utils import rich_utils
|
39
39
|
from sky.utils import subprocess_utils
|
40
40
|
from sky.utils import ux_utils
|
41
41
|
|
42
42
|
if typing.TYPE_CHECKING:
|
43
|
+
import sky
|
43
44
|
from sky import dag as dag_lib
|
44
45
|
|
45
46
|
logger = sky_logging.init_logger(__name__)
|
46
47
|
|
47
|
-
# Add user hash so that two users don't have the same controller VM on
|
48
|
-
# shared-account clouds such as GCP.
|
49
|
-
JOB_CONTROLLER_NAME: str = (
|
50
|
-
f'sky-jobs-controller-{common_utils.get_user_hash()}')
|
51
|
-
LEGACY_JOB_CONTROLLER_NAME: str = (
|
52
|
-
f'sky-spot-controller-{common_utils.get_user_hash()}')
|
53
48
|
SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'
|
54
49
|
# Controller checks its job's status every this many seconds.
|
55
50
|
JOB_STATUS_CHECK_GAP_SECONDS = 20
|
@@ -86,6 +81,7 @@ class UserSignal(enum.Enum):
|
|
86
81
|
# ====== internal functions ======
|
87
82
|
def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
|
88
83
|
"""Terminate the cluster."""
|
84
|
+
from sky import core # pylint: disable=import-outside-toplevel
|
89
85
|
retry_cnt = 0
|
90
86
|
# In some cases, e.g. botocore.exceptions.NoCredentialsError due to AWS
|
91
87
|
# metadata service throttling, the failed sky.down attempt can take 10-11
|
@@ -102,7 +98,7 @@ def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
|
|
102
98
|
while True:
|
103
99
|
try:
|
104
100
|
usage_lib.messages.usage.set_internal()
|
105
|
-
|
101
|
+
core.down(cluster_name)
|
106
102
|
return
|
107
103
|
except exceptions.ClusterDoesNotExist:
|
108
104
|
# The cluster is already down.
|
@@ -274,15 +270,23 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
|
|
274
270
|
failure_reason = ('Inconsistent internal job state. This is a bug.')
|
275
271
|
elif pid is None:
|
276
272
|
# Non-legacy job and controller process has not yet started.
|
277
|
-
|
273
|
+
controller_status = job_lib.get_status(job_id)
|
274
|
+
if controller_status == job_lib.JobStatus.FAILED_SETUP:
|
275
|
+
# We should fail the case where the controller status is
|
276
|
+
# FAILED_SETUP, as it is due to the failure of dependency setup
|
277
|
+
# on the controller.
|
278
|
+
# TODO(cooperc): We should also handle the case where controller
|
279
|
+
# status is FAILED_DRIVER or FAILED.
|
280
|
+
logger.error('Failed to setup the cloud dependencies for '
|
281
|
+
'the managed job.')
|
282
|
+
elif (schedule_state in [
|
278
283
|
managed_job_state.ManagedJobScheduleState.INACTIVE,
|
279
|
-
managed_job_state.ManagedJobScheduleState.WAITING
|
280
|
-
|
281
|
-
#
|
284
|
+
managed_job_state.ManagedJobScheduleState.WAITING,
|
285
|
+
]):
|
286
|
+
# It is expected that the controller hasn't been started yet.
|
282
287
|
continue
|
283
|
-
|
284
|
-
|
285
|
-
managed_job_state.ManagedJobScheduleState.LAUNCHING):
|
288
|
+
elif (schedule_state ==
|
289
|
+
managed_job_state.ManagedJobScheduleState.LAUNCHING):
|
286
290
|
# This is unlikely but technically possible. There's a brief
|
287
291
|
# period between marking job as scheduled (LAUNCHING) and
|
288
292
|
# actually launching the controller process and writing the pid
|
@@ -368,7 +372,7 @@ def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
|
|
368
372
|
subprocess_utils.handle_returncode(returncode, code,
|
369
373
|
'Failed to get job time.',
|
370
374
|
stdout + stderr)
|
371
|
-
stdout =
|
375
|
+
stdout = message_utils.decode_payload(stdout)
|
372
376
|
return float(stdout)
|
373
377
|
|
374
378
|
|
@@ -531,7 +535,8 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
|
|
531
535
|
f'{managed_job_state.get_failure_reason(job_id)}')
|
532
536
|
log_file = managed_job_state.get_local_log_file(job_id, None)
|
533
537
|
if log_file is not None:
|
534
|
-
with open(log_file, 'r',
|
538
|
+
with open(os.path.expanduser(log_file), 'r',
|
539
|
+
encoding='utf-8') as f:
|
535
540
|
# Stream the logs to the console without reading the whole
|
536
541
|
# file into memory.
|
537
542
|
start_streaming = False
|
@@ -878,12 +883,12 @@ def dump_managed_job_queue() -> str:
|
|
878
883
|
job['cluster_resources'] = '-'
|
879
884
|
job['region'] = '-'
|
880
885
|
|
881
|
-
return
|
886
|
+
return message_utils.encode_payload(jobs)
|
882
887
|
|
883
888
|
|
884
889
|
def load_managed_job_queue(payload: str) -> List[Dict[str, Any]]:
|
885
890
|
"""Load job queue from json string."""
|
886
|
-
jobs =
|
891
|
+
jobs = message_utils.decode_payload(payload)
|
887
892
|
for job in jobs:
|
888
893
|
job['status'] = managed_job_state.ManagedJobStatus(job['status'])
|
889
894
|
return jobs
|
@@ -1159,9 +1164,9 @@ class ManagedJobCodeGen:
|
|
1159
1164
|
@classmethod
|
1160
1165
|
def get_all_job_ids_by_name(cls, job_name: Optional[str]) -> str:
|
1161
1166
|
code = textwrap.dedent(f"""\
|
1162
|
-
from sky.utils import
|
1167
|
+
from sky.utils import message_utils
|
1163
1168
|
job_id = managed_job_state.get_all_job_ids_by_name({job_name!r})
|
1164
|
-
print(
|
1169
|
+
print(message_utils.encode_payload(job_id), end="", flush=True)
|
1165
1170
|
""")
|
1166
1171
|
return cls._build(code)
|
1167
1172
|
|
@@ -1197,5 +1202,9 @@ class ManagedJobCodeGen:
|
|
1197
1202
|
@classmethod
|
1198
1203
|
def _build(cls, code: str) -> str:
|
1199
1204
|
generated_code = cls._PREFIX + '\n' + code
|
1200
|
-
|
1201
|
-
|
1205
|
+
# Use the local user id to make sure the operation goes to the correct
|
1206
|
+
# user.
|
1207
|
+
return (
|
1208
|
+
f'export {constants.USER_ID_ENV_VAR}='
|
1209
|
+
f'"{common_utils.get_user_hash()}"; '
|
1210
|
+
f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(generated_code)}')
|
sky/models.py
ADDED
@@ -0,0 +1,27 @@
|
|
1
|
+
"""Data Models for SkyPilot."""
|
2
|
+
|
3
|
+
import collections
|
4
|
+
import dataclasses
|
5
|
+
from typing import Dict, Optional
|
6
|
+
|
7
|
+
|
8
|
+
@dataclasses.dataclass
|
9
|
+
class User:
|
10
|
+
# User hash
|
11
|
+
id: str
|
12
|
+
# Display name of the user
|
13
|
+
name: Optional[str] = None
|
14
|
+
|
15
|
+
|
16
|
+
RealtimeGpuAvailability = collections.namedtuple(
|
17
|
+
'RealtimeGpuAvailability', ['gpu', 'counts', 'capacity', 'available'])
|
18
|
+
|
19
|
+
|
20
|
+
@dataclasses.dataclass
|
21
|
+
class KubernetesNodeInfo:
|
22
|
+
"""Dataclass to store Kubernetes node information."""
|
23
|
+
name: str
|
24
|
+
accelerator_type: Optional[str]
|
25
|
+
# Resources available on the node. E.g., {'nvidia.com/gpu': '2'}
|
26
|
+
total: Dict[str, int]
|
27
|
+
free: Dict[str, int]
|