skypilot-nightly 1.0.0.dev20250215__py3-none-any.whl → 1.0.0.dev20250217__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172) hide show
  1. sky/__init__.py +48 -22
  2. sky/adaptors/aws.py +2 -1
  3. sky/adaptors/azure.py +4 -4
  4. sky/adaptors/cloudflare.py +4 -4
  5. sky/adaptors/kubernetes.py +8 -8
  6. sky/authentication.py +42 -45
  7. sky/backends/backend.py +2 -2
  8. sky/backends/backend_utils.py +108 -221
  9. sky/backends/cloud_vm_ray_backend.py +283 -282
  10. sky/benchmark/benchmark_utils.py +6 -2
  11. sky/check.py +40 -28
  12. sky/cli.py +1213 -1116
  13. sky/client/__init__.py +1 -0
  14. sky/client/cli.py +5644 -0
  15. sky/client/common.py +345 -0
  16. sky/client/sdk.py +1757 -0
  17. sky/cloud_stores.py +12 -6
  18. sky/clouds/__init__.py +0 -2
  19. sky/clouds/aws.py +20 -13
  20. sky/clouds/azure.py +5 -3
  21. sky/clouds/cloud.py +1 -1
  22. sky/clouds/cudo.py +2 -1
  23. sky/clouds/do.py +2 -1
  24. sky/clouds/fluidstack.py +3 -2
  25. sky/clouds/gcp.py +10 -8
  26. sky/clouds/ibm.py +8 -7
  27. sky/clouds/kubernetes.py +7 -6
  28. sky/clouds/lambda_cloud.py +8 -7
  29. sky/clouds/oci.py +4 -3
  30. sky/clouds/paperspace.py +2 -1
  31. sky/clouds/runpod.py +2 -1
  32. sky/clouds/scp.py +8 -7
  33. sky/clouds/service_catalog/__init__.py +3 -3
  34. sky/clouds/service_catalog/aws_catalog.py +7 -1
  35. sky/clouds/service_catalog/common.py +4 -2
  36. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +2 -2
  37. sky/clouds/utils/oci_utils.py +1 -1
  38. sky/clouds/vast.py +2 -1
  39. sky/clouds/vsphere.py +2 -1
  40. sky/core.py +263 -99
  41. sky/dag.py +4 -0
  42. sky/data/mounting_utils.py +2 -1
  43. sky/data/storage.py +97 -35
  44. sky/data/storage_utils.py +69 -9
  45. sky/exceptions.py +138 -5
  46. sky/execution.py +47 -50
  47. sky/global_user_state.py +105 -22
  48. sky/jobs/__init__.py +12 -14
  49. sky/jobs/client/__init__.py +0 -0
  50. sky/jobs/client/sdk.py +296 -0
  51. sky/jobs/constants.py +30 -1
  52. sky/jobs/controller.py +12 -6
  53. sky/jobs/dashboard/dashboard.py +2 -6
  54. sky/jobs/recovery_strategy.py +22 -29
  55. sky/jobs/server/__init__.py +1 -0
  56. sky/jobs/{core.py → server/core.py} +101 -34
  57. sky/jobs/server/dashboard_utils.py +64 -0
  58. sky/jobs/server/server.py +182 -0
  59. sky/jobs/utils.py +32 -23
  60. sky/models.py +27 -0
  61. sky/optimizer.py +9 -11
  62. sky/provision/__init__.py +6 -3
  63. sky/provision/aws/config.py +2 -2
  64. sky/provision/aws/instance.py +1 -1
  65. sky/provision/azure/instance.py +1 -1
  66. sky/provision/cudo/instance.py +1 -1
  67. sky/provision/do/instance.py +1 -1
  68. sky/provision/do/utils.py +0 -5
  69. sky/provision/fluidstack/fluidstack_utils.py +4 -3
  70. sky/provision/fluidstack/instance.py +4 -2
  71. sky/provision/gcp/instance.py +1 -1
  72. sky/provision/instance_setup.py +2 -2
  73. sky/provision/kubernetes/constants.py +8 -0
  74. sky/provision/kubernetes/instance.py +1 -1
  75. sky/provision/kubernetes/utils.py +67 -76
  76. sky/provision/lambda_cloud/instance.py +3 -15
  77. sky/provision/logging.py +1 -1
  78. sky/provision/oci/instance.py +7 -4
  79. sky/provision/paperspace/instance.py +1 -1
  80. sky/provision/provisioner.py +3 -2
  81. sky/provision/runpod/instance.py +1 -1
  82. sky/provision/vast/instance.py +1 -1
  83. sky/provision/vast/utils.py +2 -1
  84. sky/provision/vsphere/instance.py +2 -11
  85. sky/resources.py +55 -40
  86. sky/serve/__init__.py +6 -10
  87. sky/serve/client/__init__.py +0 -0
  88. sky/serve/client/sdk.py +366 -0
  89. sky/serve/constants.py +3 -0
  90. sky/serve/replica_managers.py +10 -10
  91. sky/serve/serve_utils.py +56 -36
  92. sky/serve/server/__init__.py +0 -0
  93. sky/serve/{core.py → server/core.py} +37 -17
  94. sky/serve/server/server.py +117 -0
  95. sky/serve/service.py +8 -1
  96. sky/server/__init__.py +1 -0
  97. sky/server/common.py +441 -0
  98. sky/server/constants.py +21 -0
  99. sky/server/html/log.html +174 -0
  100. sky/server/requests/__init__.py +0 -0
  101. sky/server/requests/executor.py +462 -0
  102. sky/server/requests/payloads.py +481 -0
  103. sky/server/requests/queues/__init__.py +0 -0
  104. sky/server/requests/queues/mp_queue.py +76 -0
  105. sky/server/requests/requests.py +567 -0
  106. sky/server/requests/serializers/__init__.py +0 -0
  107. sky/server/requests/serializers/decoders.py +192 -0
  108. sky/server/requests/serializers/encoders.py +166 -0
  109. sky/server/server.py +1095 -0
  110. sky/server/stream_utils.py +144 -0
  111. sky/setup_files/MANIFEST.in +1 -0
  112. sky/setup_files/dependencies.py +12 -4
  113. sky/setup_files/setup.py +1 -1
  114. sky/sky_logging.py +9 -13
  115. sky/skylet/autostop_lib.py +2 -2
  116. sky/skylet/constants.py +46 -12
  117. sky/skylet/events.py +5 -6
  118. sky/skylet/job_lib.py +78 -66
  119. sky/skylet/log_lib.py +17 -11
  120. sky/skypilot_config.py +79 -94
  121. sky/task.py +119 -73
  122. sky/templates/aws-ray.yml.j2 +4 -4
  123. sky/templates/azure-ray.yml.j2 +3 -2
  124. sky/templates/cudo-ray.yml.j2 +3 -2
  125. sky/templates/fluidstack-ray.yml.j2 +3 -2
  126. sky/templates/gcp-ray.yml.j2 +3 -2
  127. sky/templates/ibm-ray.yml.j2 +3 -2
  128. sky/templates/jobs-controller.yaml.j2 +1 -12
  129. sky/templates/kubernetes-ray.yml.j2 +3 -2
  130. sky/templates/lambda-ray.yml.j2 +3 -2
  131. sky/templates/oci-ray.yml.j2 +3 -2
  132. sky/templates/paperspace-ray.yml.j2 +3 -2
  133. sky/templates/runpod-ray.yml.j2 +3 -2
  134. sky/templates/scp-ray.yml.j2 +3 -2
  135. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  136. sky/templates/vsphere-ray.yml.j2 +4 -2
  137. sky/templates/websocket_proxy.py +64 -0
  138. sky/usage/constants.py +8 -0
  139. sky/usage/usage_lib.py +45 -11
  140. sky/utils/accelerator_registry.py +33 -53
  141. sky/utils/admin_policy_utils.py +2 -1
  142. sky/utils/annotations.py +51 -0
  143. sky/utils/cli_utils/status_utils.py +33 -3
  144. sky/utils/cluster_utils.py +356 -0
  145. sky/utils/command_runner.py +69 -14
  146. sky/utils/common.py +74 -0
  147. sky/utils/common_utils.py +133 -93
  148. sky/utils/config_utils.py +204 -0
  149. sky/utils/control_master_utils.py +2 -3
  150. sky/utils/controller_utils.py +133 -147
  151. sky/utils/dag_utils.py +72 -24
  152. sky/utils/kubernetes/deploy_remote_cluster.sh +2 -2
  153. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  154. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  155. sky/utils/log_utils.py +83 -23
  156. sky/utils/message_utils.py +81 -0
  157. sky/utils/registry.py +127 -0
  158. sky/utils/resources_utils.py +2 -2
  159. sky/utils/rich_utils.py +213 -34
  160. sky/utils/schemas.py +19 -2
  161. sky/{status_lib.py → utils/status_lib.py} +12 -7
  162. sky/utils/subprocess_utils.py +51 -35
  163. sky/utils/timeline.py +7 -2
  164. sky/utils/ux_utils.py +95 -25
  165. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/METADATA +8 -3
  166. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/RECORD +170 -132
  167. sky/clouds/cloud_registry.py +0 -76
  168. sky/utils/cluster_yaml_utils.py +0 -24
  169. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/LICENSE +0 -0
  170. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/WHEEL +0 -0
  171. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/entry_points.txt +0 -0
  172. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,21 @@
1
1
  """SDK functions for managed jobs."""
2
2
  import os
3
+ import signal
4
+ import subprocess
3
5
  import tempfile
6
+ import time
4
7
  import typing
5
8
  from typing import Any, Dict, List, Optional, Tuple, Union
6
9
  import uuid
7
10
 
8
11
  import colorama
9
12
 
10
- import sky
11
13
  from sky import backends
14
+ from sky import core
12
15
  from sky import exceptions
16
+ from sky import execution
13
17
  from sky import provision as provision_lib
14
18
  from sky import sky_logging
15
- from sky import status_lib
16
19
  from sky import task as task_lib
17
20
  from sky.backends import backend_utils
18
21
  from sky.clouds.service_catalog import common as service_catalog_common
@@ -26,13 +29,17 @@ from sky.utils import common_utils
26
29
  from sky.utils import controller_utils
27
30
  from sky.utils import dag_utils
28
31
  from sky.utils import rich_utils
32
+ from sky.utils import status_lib
29
33
  from sky.utils import subprocess_utils
30
34
  from sky.utils import timeline
31
35
  from sky.utils import ux_utils
32
36
 
33
37
  if typing.TYPE_CHECKING:
38
+ import sky
34
39
  from sky.backends import cloud_vm_ray_backend
35
40
 
41
+ logger = sky_logging.init_logger(__name__)
42
+
36
43
 
37
44
  @timeline.event
38
45
  @usage_lib.entrypoint
@@ -40,10 +47,9 @@ def launch(
40
47
  task: Union['sky.Task', 'sky.Dag'],
41
48
  name: Optional[str] = None,
42
49
  stream_logs: bool = True,
43
- detach_run: bool = False,
44
50
  ) -> Tuple[Optional[int], Optional[backends.ResourceHandle]]:
45
51
  # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
46
- """Launch a managed job.
52
+ """Launches a managed job.
47
53
 
48
54
  Please refer to sky.cli.job_launch for documentation.
49
55
 
@@ -51,7 +57,6 @@ def launch(
51
57
  task: sky.Task, or sky.Dag (experimental; 1-task only) to launch as a
52
58
  managed job.
53
59
  name: Name of the managed job.
54
- detach_run: Whether to detach the run.
55
60
 
56
61
  Raises:
57
62
  ValueError: cluster does not exist. Or, the entrypoint is not a valid
@@ -77,7 +82,7 @@ def launch(
77
82
  with ux_utils.print_exception_no_traceback():
78
83
  raise ValueError('Only single-task or chain DAG is '
79
84
  f'allowed for job_launch. Dag: {dag}')
80
-
85
+ dag.validate()
81
86
  dag_utils.maybe_infer_and_fill_dag_and_task_names(dag)
82
87
 
83
88
  task_names = set()
@@ -120,6 +125,7 @@ def launch(
120
125
  'remote_user_config_path': remote_user_config_path,
121
126
  'modified_catalogs':
122
127
  service_catalog_common.get_modified_catalog_file_mounts(),
128
+ 'dashboard_setup_cmd': managed_job_constants.DASHBOARD_SETUP_CMD,
123
129
  **controller_utils.shared_controller_vars_to_fill(
124
130
  controller_utils.Controllers.JOBS_CONTROLLER,
125
131
  remote_user_config_path=remote_user_config_path,
@@ -143,15 +149,14 @@ def launch(
143
149
  f'{colorama.Fore.YELLOW}'
144
150
  f'Launching managed job {dag.name!r} from jobs controller...'
145
151
  f'{colorama.Style.RESET_ALL}')
146
- return sky.launch(task=controller_task,
147
- stream_logs=stream_logs,
148
- cluster_name=controller_name,
149
- detach_run=detach_run,
150
- idle_minutes_to_autostop=skylet_constants.
151
- CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP,
152
- retry_until_up=True,
153
- fast=True,
154
- _disable_controller_check=True)
152
+ return execution.launch(task=controller_task,
153
+ cluster_name=controller_name,
154
+ stream_logs=stream_logs,
155
+ idle_minutes_to_autostop=skylet_constants.
156
+ CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP,
157
+ retry_until_up=True,
158
+ fast=True,
159
+ _disable_controller_check=True)
155
160
 
156
161
 
157
162
  def queue_from_kubernetes_pod(
@@ -256,7 +261,22 @@ def _maybe_restart_controller(
256
261
  rich_utils.force_update_status(
257
262
  ux_utils.spinner_message(f'{spinner_message} - restarting '
258
263
  'controller'))
259
- handle = sky.start(jobs_controller_type.value.cluster_name)
264
+ handle = core.start(cluster_name=jobs_controller_type.value.cluster_name)
265
+ # Make sure the dashboard is running when the controller is restarted.
266
+ # We should not directly use execution.launch() and have the dashboard cmd
267
+ # in the task setup because since we are using detached_setup, it will
268
+ # become a job on controller which messes up the job IDs (we assume the
269
+ # job ID in controller's job queue is consistent with managed job IDs).
270
+ with rich_utils.safe_status(
271
+ ux_utils.spinner_message('Starting dashboard...')):
272
+ runner = handle.get_command_runners()[0]
273
+ user_hash = common_utils.get_user_hash()
274
+ runner.run(
275
+ f'export '
276
+ f'{skylet_constants.USER_ID_ENV_VAR}={user_hash!r}; '
277
+ f'{managed_job_constants.DASHBOARD_SETUP_CMD}',
278
+ stream_logs=True,
279
+ )
260
280
  controller_status = status_lib.ClusterStatus.UP
261
281
  rich_utils.force_update_status(ux_utils.spinner_message(spinner_message))
262
282
 
@@ -267,7 +287,7 @@ def _maybe_restart_controller(
267
287
  @usage_lib.entrypoint
268
288
  def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
269
289
  # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
270
- """Get statuses of managed jobs.
290
+ """Gets statuses of managed jobs.
271
291
 
272
292
  Please refer to sky.cli.job_queue for documentation.
273
293
 
@@ -307,14 +327,10 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
307
327
  stream_logs=False,
308
328
  separate_stderr=True)
309
329
 
310
- try:
311
- subprocess_utils.handle_returncode(returncode,
312
- code,
313
- 'Failed to fetch managed jobs',
314
- job_table_payload + stderr,
315
- stream_logs=False)
316
- except exceptions.CommandError as e:
317
- raise RuntimeError(str(e)) from e
330
+ if returncode != 0:
331
+ logger.error(job_table_payload + stderr)
332
+ raise RuntimeError('Failed to fetch managed jobs with returncode: '
333
+ f'{returncode}')
318
334
 
319
335
  jobs = managed_job_utils.load_managed_job_queue(job_table_payload)
320
336
  if skip_finished:
@@ -334,7 +350,7 @@ def cancel(name: Optional[str] = None,
334
350
  job_ids: Optional[List[int]] = None,
335
351
  all: bool = False) -> None:
336
352
  # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
337
- """Cancel managed jobs.
353
+ """Cancels managed jobs.
338
354
 
339
355
  Please refer to sky.cli.job_cancel for documentation.
340
356
 
@@ -428,22 +444,73 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
428
444
  controller=controller)
429
445
 
430
446
 
447
def start_dashboard_forwarding(refresh: bool = False) -> Tuple[int, int]:
    """Opens a dashboard for managed jobs (needs controller to be UP).

    Looks up the jobs controller through ``_maybe_restart_controller``, picks
    a free local port, and spawns a background shell process that runs the
    command runner's port-forward command towards the controller's dashboard
    port.

    Args:
        refresh: Forwarded unchanged to ``_maybe_restart_controller``.

    Returns:
        A ``(local_port, pid)`` tuple, where ``pid`` is the PID of the spawned
        shell process running the port-forward command.
    """
    # TODO(SKY-1212): ideally, the controller/dashboard server should expose the
    # API perhaps via REST. Then here we would (1) not have to use SSH to try to
    # see if the controller is UP first, which is slow; (2) not have to run SSH
    # port forwarding first (we'd just launch a local dashboard which would make
    # REST API calls to the controller dashboard server).
    logger.info('Starting dashboard')
    controller_handle = _maybe_restart_controller(
        refresh=refresh,
        stopped_message=(
            'Dashboard is not available if jobs controller is not up. Run '
            'a managed job first or run: sky jobs queue --refresh'),
        spinner_message='Checking jobs controller')

    # SSH forward a free local port to remote's dashboard port.
    dashboard_port = skylet_constants.SPOT_DASHBOARD_REMOTE_PORT
    local_port = common_utils.find_free_port(dashboard_port)
    head_runner = controller_handle.get_command_runners()[0]
    forward_argv = head_runner.port_forward_command(
        port_forward=[(local_port, dashboard_port)], connect_timeout=1)
    # Both stdout and stderr of the forwarding command go to a per-user log.
    log_file = ('~/sky_logs/api_server/'
                f'dashboard-{common_utils.get_user_hash()}.log')
    shell_command = ' '.join(forward_argv) + f' > {log_file} 2>&1'
    logger.info(f'Forwarding port: {colorama.Style.DIM}{shell_command}'
                f'{colorama.Style.RESET_ALL}')

    # start_new_session=True puts the process in its own session, so the whole
    # process group can later be signalled via its PID.
    forwarder = subprocess.Popen(shell_command,
                                 shell=True,
                                 start_new_session=True)
    time.sleep(3)  # Added delay for ssh_command to initialize.
    logger.info(f'{colorama.Fore.GREEN}Dashboard is now available at: '
                f'http://127.0.0.1:{local_port}{colorama.Style.RESET_ALL}')

    return local_port, forwarder.pid
484
+
485
+
486
def stop_dashboard_forwarding(pid: int) -> None:
    """Terminates the port-forwarding process group that ``pid`` belongs to.

    Args:
        pid: PID of a process in the port-forwarding process group.
    """
    # Exit the ssh command when the context manager is closed.
    try:
        process_group = os.getpgid(pid)
        os.killpg(process_group, signal.SIGTERM)
    except ProcessLookupError:
        # This happens if jobs controller is auto-stopped.
        pass
    logger.info('Forwarding port closed. Exiting.')
494
+
495
+
431
496
  @usage_lib.entrypoint
432
- def sync_down_logs(
497
+ def download_logs(
433
498
  name: Optional[str],
434
499
  job_id: Optional[int],
435
500
  refresh: bool,
436
501
  controller: bool,
437
- local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> None:
502
+ local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> Dict[str, str]:
438
503
  """Sync down logs of managed jobs.
439
504
 
440
505
  Please refer to sky.cli.job_logs for documentation.
441
506
 
507
+ Returns:
508
+ A dictionary mapping job ID to the local path.
509
+
442
510
  Raises:
443
511
  ValueError: invalid arguments.
444
512
  sky.exceptions.ClusterNotUpError: the jobs controller is not up.
445
513
  """
446
- # TODO(zhwu): Automatically restart the jobs controller
447
514
  if name is not None and job_id is not None:
448
515
  with ux_utils.print_exception_no_traceback():
449
516
  raise ValueError('Cannot specify both name and job_id.')
@@ -467,8 +534,8 @@ def sync_down_logs(
467
534
  backend = backend_utils.get_backend_from_handle(handle)
468
535
  assert isinstance(backend, backends.CloudVmRayBackend), backend
469
536
 
470
- backend.sync_down_managed_job_logs(handle,
471
- job_id=job_id,
472
- job_name=name,
473
- controller=controller,
474
- local_dir=local_dir)
537
+ return backend.sync_down_managed_job_logs(handle,
538
+ job_id=job_id,
539
+ job_name=name,
540
+ controller=controller,
541
+ local_dir=local_dir)
@@ -0,0 +1,64 @@
1
+ """Persistent dashboard sessions."""
2
+ import pathlib
3
+ from typing import Tuple
4
+
5
+ import filelock
6
+
7
+ from sky.utils import db_utils
8
+
9
+
10
def create_dashboard_table(cursor, conn):
    """Creates the ``dashboard_sessions`` table if missing and commits.

    Args:
        cursor: Database cursor used to execute the DDL statement.
        conn: Connection on which the change is committed.
    """
    # One row per user hash, holding a port and a pid.
    create_stmt = """\
        CREATE TABLE IF NOT EXISTS dashboard_sessions (
        user_hash TEXT PRIMARY KEY,
        port INTEGER,
        pid INTEGER)"""
    cursor.execute(create_stmt)
    conn.commit()
17
+
18
+
19
+ def _get_db_path() -> str:
20
+ path = pathlib.Path('~/.sky/dashboard/sessions.db')
21
+ path = path.expanduser().absolute()
22
+ path.parent.mkdir(parents=True, exist_ok=True)
23
+ return str(path)
24
+
25
+
26
+ DB_PATH = _get_db_path()
27
+ db_utils.SQLiteConn(DB_PATH, create_dashboard_table)
28
+ LOCK_FILE_PATH = '~/.sky/dashboard/sessions-{user_hash}.lock'
29
+
30
+
31
def get_dashboard_session(user_hash: str) -> Tuple[int, int]:
    """Get the port and pid of the dashboard session for the user.

    Args:
        user_hash: Key of the row to look up in ``dashboard_sessions``.

    Returns:
        The stored ``(port, pid)`` row, or ``(0, 0)`` when the user has no
        recorded session.
    """
    select_stmt = 'SELECT port, pid FROM dashboard_sessions WHERE user_hash=?'
    with db_utils.safe_cursor(DB_PATH) as cursor:
        cursor.execute(select_stmt, (user_hash,))
        row = cursor.fetchone()
        # (0, 0) is the "no session" sentinel.
        return (0, 0) if row is None else row
41
+
42
+
43
def add_dashboard_session(user_hash: str, port: int, pid: int) -> None:
    """Add a dashboard session for the user.

    An existing row for the same ``user_hash`` is replaced.

    Args:
        user_hash: Primary key of the session row.
        port: Port value to store for the session.
        pid: PID value to store for the session.
    """
    upsert_stmt = (
        'INSERT OR REPLACE INTO dashboard_sessions (user_hash, port, pid) '
        'VALUES (?, ?, ?)')
    row = (user_hash, port, pid)
    with db_utils.safe_cursor(DB_PATH) as cursor:
        cursor.execute(upsert_stmt, row)
49
+
50
+
51
def remove_dashboard_session(user_hash: str) -> None:
    """Remove the dashboard session for the user.

    Deletes the user's row from ``dashboard_sessions`` and removes the
    per-user lock file, if present.

    Args:
        user_hash: Key of the session row and lock file to remove.
    """
    with db_utils.safe_cursor(DB_PATH) as cursor:
        cursor.execute('DELETE FROM dashboard_sessions WHERE user_hash=?',
                       (user_hash,))
    # LOCK_FILE_PATH starts with '~', which pathlib does not expand on its
    # own. Expand it the same way get_dashboard_lock_for_user() does so the
    # unlink targets the real lock file instead of a literal './~/...' path.
    lock_path = pathlib.Path(LOCK_FILE_PATH.format(user_hash=user_hash))
    lock_path = lock_path.expanduser().absolute()
    lock_path.unlink(missing_ok=True)
58
+
59
+
60
def get_dashboard_lock_for_user(user_hash: str) -> filelock.FileLock:
    """Returns a file lock for the user's dashboard session.

    The lock file path is derived from ``LOCK_FILE_PATH``; its parent
    directory is created if it does not exist.

    Args:
        user_hash: Value substituted into the lock file name.
    """
    lock_file = pathlib.Path(LOCK_FILE_PATH.format(user_hash=user_hash))
    lock_file = lock_file.expanduser().absolute()
    lock_file.parent.mkdir(parents=True, exist_ok=True)
    return filelock.FileLock(lock_file)
@@ -0,0 +1,182 @@
1
+ """REST API for managed jobs."""
2
+ import os
3
+
4
+ import fastapi
5
+ import httpx
6
+
7
+ from sky import sky_logging
8
+ from sky.jobs.server import core
9
+ from sky.jobs.server import dashboard_utils
10
+ from sky.server import common as server_common
11
+ from sky.server import stream_utils
12
+ from sky.server.requests import executor
13
+ from sky.server.requests import payloads
14
+ from sky.server.requests import requests as api_requests
15
+ from sky.skylet import constants
16
+ from sky.utils import common
17
+ from sky.utils import common_utils
18
+
19
+ logger = sky_logging.init_logger(__name__)
20
+
21
+ router = fastapi.APIRouter()
22
+
23
+
24
def _get_controller_name(request_body: payloads.RequestBody) -> str:
    """Returns the jobs controller name for the user hash in the request."""
    return common.get_controller_name(common.ControllerType.JOBS,
                                      request_body.user_hash)
27
+
28
+
29
@router.post('/launch')
async def launch(request: fastapi.Request,
                 jobs_launch_body: payloads.JobsLaunchBody) -> None:
    """Schedules ``core.launch`` as a LONG request on the executor."""
    request_id = request.state.request_id
    controller_name = _get_controller_name(jobs_launch_body)
    executor.schedule_request(
        request_id=request_id,
        request_name='jobs.launch',
        request_body=jobs_launch_body,
        func=core.launch,
        schedule_type=api_requests.ScheduleType.LONG,
        request_cluster_name=controller_name,
    )
40
+
41
+
42
@router.post('/queue')
async def queue(request: fastapi.Request,
                jobs_queue_body: payloads.JobsQueueBody) -> None:
    """Schedules ``core.queue`` on the executor.

    The request is scheduled as LONG when ``refresh`` is set and as SHORT
    otherwise.
    """
    if jobs_queue_body.refresh:
        schedule_type = api_requests.ScheduleType.LONG
    else:
        schedule_type = api_requests.ScheduleType.SHORT
    executor.schedule_request(
        request_id=request.state.request_id,
        request_name='jobs.queue',
        request_body=jobs_queue_body,
        func=core.queue,
        schedule_type=schedule_type,
        request_cluster_name=_get_controller_name(jobs_queue_body),
    )
54
+
55
+
56
@router.post('/cancel')
async def cancel(request: fastapi.Request,
                 jobs_cancel_body: payloads.JobsCancelBody) -> None:
    """Schedules ``core.cancel`` as a SHORT request on the executor."""
    schedule_kwargs = {
        'request_id': request.state.request_id,
        'request_name': 'jobs.cancel',
        'request_body': jobs_cancel_body,
        'func': core.cancel,
        'schedule_type': api_requests.ScheduleType.SHORT,
        'request_cluster_name': _get_controller_name(jobs_cancel_body),
    }
    executor.schedule_request(**schedule_kwargs)
67
+
68
+
69
@router.post('/logs')
async def logs(
    request: fastapi.Request, jobs_logs_body: payloads.JobsLogsBody,
    background_tasks: fastapi.BackgroundTasks
) -> fastapi.responses.StreamingResponse:
    """Schedules ``core.tail_logs`` and streams the request's log file.

    Returns:
        A streaming response over the scheduled request's log path.
    """
    # Use LONG when refresh is requested and SHORT otherwise, the same mapping
    # as the /queue and /download_logs routes in this module. The previous
    # expression had the two branches swapped.
    # NOTE(review): presumably refresh=True may need to restart the jobs
    # controller, which is why it belongs on the LONG queue - confirm.
    schedule_type = (api_requests.ScheduleType.LONG if jobs_logs_body.refresh
                     else api_requests.ScheduleType.SHORT)
    executor.schedule_request(
        request_id=request.state.request_id,
        request_name='jobs.logs',
        request_body=jobs_logs_body,
        func=core.tail_logs,
        schedule_type=schedule_type,
        request_cluster_name=_get_controller_name(jobs_logs_body),
    )
    request_task = api_requests.get_request(request.state.request_id)

    return stream_utils.stream_response(
        request_id=request_task.request_id,
        logs_path=request_task.log_path,
        background_tasks=background_tasks,
    )
90
+
91
+
92
@router.post('/download_logs')
async def download_logs(
        request: fastapi.Request,
        jobs_download_logs_body: payloads.JobsDownloadLogsBody) -> None:
    """Schedules ``core.download_logs`` into a per-user directory.

    The body's ``local_dir`` is overwritten with the per-user logs directory
    on the API server before the request is scheduled. The request is
    scheduled as LONG when ``refresh`` is set and as SHORT otherwise.
    """
    env_vars = jobs_download_logs_body.env_vars
    user_logs_dir = server_common.api_server_user_logs_dir_prefix(
        env_vars[constants.USER_ID_ENV_VAR])
    user_logs_dir.expanduser().mkdir(parents=True, exist_ok=True)
    # We should reuse the original request body, so that the env vars, such as
    # user hash, are kept the same.
    jobs_download_logs_body.local_dir = str(user_logs_dir)

    if jobs_download_logs_body.refresh:
        schedule_type = api_requests.ScheduleType.LONG
    else:
        schedule_type = api_requests.ScheduleType.SHORT
    executor.schedule_request(
        request_id=request.state.request_id,
        request_name='jobs.download_logs',
        request_body=jobs_download_logs_body,
        func=core.download_logs,
        schedule_type=schedule_type,
        request_cluster_name=_get_controller_name(jobs_download_logs_body),
    )
112
+
113
+
114
@router.get('/dashboard')
async def dashboard(request: fastapi.Request,
                    user_hash: str) -> fastapi.Response:
    """Proxies a GET request to the user's managed jobs dashboard.

    Under the per-user dashboard lock, looks up the recorded (port, pid)
    session. If none is recorded, or a previous attempt failed to reach it,
    port forwarding is (re)started and the session is recorded. The dashboard
    is then pinged, with up to three attempts in total. Once reachable, the
    request is forwarded and the dashboard's response is returned.

    Args:
        request: Incoming request; its raw headers are forwarded.
        user_hash: User whose dashboard session should be used.

    Raises:
        fastapi.HTTPException: 503 if port forwarding cannot be started or the
            dashboard is unreachable after all attempts; 502 if forwarding the
            request fails.
    """
    # Find the port for the dashboard of the user
    # NOTE(review): this mutates process-wide environment state inside an
    # async handler; concurrent requests for different users may interleave.
    os.environ[constants.USER_ID_ENV_VAR] = user_hash
    server_common.reload_for_new_request(client_entrypoint=None,
                                         client_command=None)
    logger.info(f'Starting dashboard for user hash: {user_hash}')

    with dashboard_utils.get_dashboard_lock_for_user(user_hash):
        max_retries = 3
        for attempt in range(max_retries):
            port, pid = dashboard_utils.get_dashboard_session(user_hash)
            if port == 0 or attempt > 0:
                # Let the client know that we are waiting for starting the
                # dashboard.
                # NOTE(review): on a retry the previously recorded forwarding
                # process (pid) is not stopped here - confirm whether that is
                # intended.
                try:
                    port, pid = core.start_dashboard_forwarding()
                except Exception as e:  # pylint: disable=broad-except
                    # We catch all exceptions to gracefully handle unknown
                    # errors and raise an HTTPException to the client.
                    msg = (
                        'Dashboard failed to start: '
                        f'{common_utils.format_exception(e, use_bracket=True)}')
                    logger.error(msg)
                    raise fastapi.HTTPException(status_code=503,
                                                detail=msg) from e
                dashboard_utils.add_dashboard_session(user_hash, port, pid)

            # Assuming the dashboard is forwarded to localhost on the API server
            dashboard_url = f'http://localhost:{port}'
            try:
                # Ping the dashboard to check if it's still running
                async with httpx.AsyncClient() as client:
                    await client.request('GET', dashboard_url, timeout=1)
                break  # Connection successful, proceed with the request
            except Exception as e:  # pylint: disable=broad-except
                # We catch all exceptions to gracefully handle unknown
                # errors and retry or raise an HTTPException to the client.
                msg = (
                    f'Dashboard connection attempt {attempt + 1} failed with '
                    f'{common_utils.format_exception(e, use_bracket=True)}')
                logger.info(msg)
                if attempt == max_retries - 1:
                    raise fastapi.HTTPException(status_code=503,
                                                detail=msg) from e

    # Create a client session to forward the request
    try:
        async with httpx.AsyncClient() as client:
            # Make the request and get the response
            response = await client.request(
                method='GET',
                url=f'{dashboard_url}',
                headers=request.headers.raw,
            )

            # Create a new response with the content already read
            content = await response.aread()
            # httpx hands back the decoded, fully buffered body, so the
            # upstream framing/encoding headers no longer describe `content`.
            # Drop them and let the response recompute Content-Length.
            stale_headers = ('content-encoding', 'content-length',
                             'transfer-encoding')
            forwarded_headers = {
                key: value
                for key, value in response.headers.items()
                if key.lower() not in stale_headers
            }
            return fastapi.Response(
                content=content,
                status_code=response.status_code,
                headers=forwarded_headers,
                media_type=response.headers.get('content-type'))
    except Exception as e:  # pylint: disable=broad-except
        msg = (f'Failed to forward request to dashboard: '
               f'{common_utils.format_exception(e, use_bracket=True)}')
        logger.error(msg)
        raise fastapi.HTTPException(status_code=502, detail=msg) from e
sky/jobs/utils.py CHANGED
@@ -20,7 +20,6 @@ import filelock
20
20
  import psutil
21
21
  from typing_extensions import Literal
22
22
 
23
- import sky
24
23
  from sky import backends
25
24
  from sky import exceptions
26
25
  from sky import global_user_state
@@ -35,21 +34,17 @@ from sky.skylet import log_lib
35
34
  from sky.usage import usage_lib
36
35
  from sky.utils import common_utils
37
36
  from sky.utils import log_utils
37
+ from sky.utils import message_utils
38
38
  from sky.utils import rich_utils
39
39
  from sky.utils import subprocess_utils
40
40
  from sky.utils import ux_utils
41
41
 
42
42
  if typing.TYPE_CHECKING:
43
+ import sky
43
44
  from sky import dag as dag_lib
44
45
 
45
46
  logger = sky_logging.init_logger(__name__)
46
47
 
47
- # Add user hash so that two users don't have the same controller VM on
48
- # shared-account clouds such as GCP.
49
- JOB_CONTROLLER_NAME: str = (
50
- f'sky-jobs-controller-{common_utils.get_user_hash()}')
51
- LEGACY_JOB_CONTROLLER_NAME: str = (
52
- f'sky-spot-controller-{common_utils.get_user_hash()}')
53
48
  SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'
54
49
  # Controller checks its job's status every this many seconds.
55
50
  JOB_STATUS_CHECK_GAP_SECONDS = 20
@@ -86,6 +81,7 @@ class UserSignal(enum.Enum):
86
81
  # ====== internal functions ======
87
82
  def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
88
83
  """Terminate the cluster."""
84
+ from sky import core # pylint: disable=import-outside-toplevel
89
85
  retry_cnt = 0
90
86
  # In some cases, e.g. botocore.exceptions.NoCredentialsError due to AWS
91
87
  # metadata service throttling, the failed sky.down attempt can take 10-11
@@ -102,7 +98,7 @@ def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
102
98
  while True:
103
99
  try:
104
100
  usage_lib.messages.usage.set_internal()
105
- sky.down(cluster_name)
101
+ core.down(cluster_name)
106
102
  return
107
103
  except exceptions.ClusterDoesNotExist:
108
104
  # The cluster is already down.
@@ -274,15 +270,23 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
274
270
  failure_reason = ('Inconsistent internal job state. This is a bug.')
275
271
  elif pid is None:
276
272
  # Non-legacy job and controller process has not yet started.
277
- if schedule_state in (
273
+ controller_status = job_lib.get_status(job_id)
274
+ if controller_status == job_lib.JobStatus.FAILED_SETUP:
275
+ # We should fail the case where the controller status is
276
+ # FAILED_SETUP, as it is due to the failure of dependency setup
277
+ # on the controller.
278
+ # TODO(cooperc): We should also handle the case where controller
279
+ # status is FAILED_DRIVER or FAILED.
280
+ logger.error('Failed to setup the cloud dependencies for '
281
+ 'the managed job.')
282
+ elif (schedule_state in [
278
283
  managed_job_state.ManagedJobScheduleState.INACTIVE,
279
- managed_job_state.ManagedJobScheduleState.WAITING):
280
- # For these states, the controller hasn't been started yet.
281
- # This is expected.
284
+ managed_job_state.ManagedJobScheduleState.WAITING,
285
+ ]):
286
+ # It is expected that the controller hasn't been started yet.
282
287
  continue
283
-
284
- if (schedule_state ==
285
- managed_job_state.ManagedJobScheduleState.LAUNCHING):
288
+ elif (schedule_state ==
289
+ managed_job_state.ManagedJobScheduleState.LAUNCHING):
286
290
  # This is unlikely but technically possible. There's a brief
287
291
  # period between marking job as scheduled (LAUNCHING) and
288
292
  # actually launching the controller process and writing the pid
@@ -368,7 +372,7 @@ def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
368
372
  subprocess_utils.handle_returncode(returncode, code,
369
373
  'Failed to get job time.',
370
374
  stdout + stderr)
371
- stdout = common_utils.decode_payload(stdout)
375
+ stdout = message_utils.decode_payload(stdout)
372
376
  return float(stdout)
373
377
 
374
378
 
@@ -531,7 +535,8 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
531
535
  f'{managed_job_state.get_failure_reason(job_id)}')
532
536
  log_file = managed_job_state.get_local_log_file(job_id, None)
533
537
  if log_file is not None:
534
- with open(log_file, 'r', encoding='utf-8') as f:
538
+ with open(os.path.expanduser(log_file), 'r',
539
+ encoding='utf-8') as f:
535
540
  # Stream the logs to the console without reading the whole
536
541
  # file into memory.
537
542
  start_streaming = False
@@ -878,12 +883,12 @@ def dump_managed_job_queue() -> str:
878
883
  job['cluster_resources'] = '-'
879
884
  job['region'] = '-'
880
885
 
881
- return common_utils.encode_payload(jobs)
886
+ return message_utils.encode_payload(jobs)
882
887
 
883
888
 
884
889
  def load_managed_job_queue(payload: str) -> List[Dict[str, Any]]:
885
890
  """Load job queue from json string."""
886
- jobs = common_utils.decode_payload(payload)
891
+ jobs = message_utils.decode_payload(payload)
887
892
  for job in jobs:
888
893
  job['status'] = managed_job_state.ManagedJobStatus(job['status'])
889
894
  return jobs
@@ -1159,9 +1164,9 @@ class ManagedJobCodeGen:
1159
1164
  @classmethod
1160
1165
  def get_all_job_ids_by_name(cls, job_name: Optional[str]) -> str:
1161
1166
  code = textwrap.dedent(f"""\
1162
- from sky.utils import common_utils
1167
+ from sky.utils import message_utils
1163
1168
  job_id = managed_job_state.get_all_job_ids_by_name({job_name!r})
1164
- print(common_utils.encode_payload(job_id), end="", flush=True)
1169
+ print(message_utils.encode_payload(job_id), end="", flush=True)
1165
1170
  """)
1166
1171
  return cls._build(code)
1167
1172
 
@@ -1197,5 +1202,9 @@ class ManagedJobCodeGen:
1197
1202
  @classmethod
1198
1203
  def _build(cls, code: str) -> str:
1199
1204
  generated_code = cls._PREFIX + '\n' + code
1200
-
1201
- return f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(generated_code)}'
1205
+ # Use the local user id to make sure the operation goes to the correct
1206
+ # user.
1207
+ return (
1208
+ f'export {constants.USER_ID_ENV_VAR}='
1209
+ f'"{common_utils.get_user_hash()}"; '
1210
+ f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(generated_code)}')
sky/models.py ADDED
@@ -0,0 +1,27 @@
1
+ """Data Models for SkyPilot."""
2
+
3
+ import collections
4
+ import dataclasses
5
+ from typing import Dict, Optional
6
+
7
+
8
@dataclasses.dataclass
class User:
    """A user record: a required ID plus an optional display name."""
    # User hash; serves as the user's identifier.
    id: str
    # Display name of the user; defaults to None when not given.
    name: Optional[str] = None
14
+
15
+
16
+ RealtimeGpuAvailability = collections.namedtuple(
17
+ 'RealtimeGpuAvailability', ['gpu', 'counts', 'capacity', 'available'])
18
+
19
+
20
@dataclasses.dataclass
class KubernetesNodeInfo:
    """Dataclass to store Kubernetes node information."""
    # Name of the node.
    name: str
    # Accelerator type of the node, or None.
    accelerator_type: Optional[str]
    # Resources available on the node, keyed by resource name.
    # E.g., {'nvidia.com/gpu': '2'}
    # NOTE(review): the example shows a string count while the annotation
    # says int - confirm which one callers actually pass.
    total: Dict[str, int]
    # Presumably the unallocated portion of `total`, same keys - TODO confirm.
    free: Dict[str, int]