skypilot-nightly 1.0.0.dev20251026__py3-none-any.whl → 1.0.0.dev20251029__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.



Files changed (82)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/coreweave.py +278 -0
  3. sky/backends/backend_utils.py +9 -6
  4. sky/backends/cloud_vm_ray_backend.py +2 -3
  5. sky/check.py +25 -13
  6. sky/client/cli/command.py +34 -15
  7. sky/client/sdk.py +4 -4
  8. sky/cloud_stores.py +73 -0
  9. sky/core.py +7 -5
  10. sky/dashboard/out/404.html +1 -1
  11. sky/dashboard/out/_next/static/{wDQ7aGvICzMNmjIaC37TT → DabuSAKsc_y0wyJxpTIdQ}/_buildManifest.js +1 -1
  12. sky/dashboard/out/_next/static/chunks/{1141-d5204f35a3388bf4.js → 1141-c3c10e2c6ed71a8f.js} +1 -1
  13. sky/dashboard/out/_next/static/chunks/2755.a239c652bf8684dd.js +26 -0
  14. sky/dashboard/out/_next/static/chunks/3294.87a13fba0058865b.js +1 -0
  15. sky/dashboard/out/_next/static/chunks/{3785.538eb23a098fc304.js → 3785.170be320e0060eaf.js} +1 -1
  16. sky/dashboard/out/_next/static/chunks/4282-49b2065b7336e496.js +1 -0
  17. sky/dashboard/out/_next/static/chunks/7615-80aa7b09f45a86d2.js +1 -0
  18. sky/dashboard/out/_next/static/chunks/8969-4ed9236db997b42b.js +1 -0
  19. sky/dashboard/out/_next/static/chunks/9360.10a3aac7aad5e3aa.js +31 -0
  20. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-ac4a217f17b087cb.js +16 -0
  21. sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-fbf2907ce2bb67e2.js → [cluster]-1704039ccaf997cf.js} +1 -1
  22. sky/dashboard/out/_next/static/chunks/pages/{jobs-0dc34cf9a8710a9f.js → jobs-7eee823559e5cf9f.js} +1 -1
  23. sky/dashboard/out/_next/static/chunks/pages/{users-96d6b8bb2dec055f.js → users-2b172f13f8538a7a.js} +1 -1
  24. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-fb1b4d3bfb047cad.js → [name]-bbfe5860c93470fd.js} +1 -1
  25. sky/dashboard/out/_next/static/chunks/pages/{workspaces-6fc994fa1ee6c6bf.js → workspaces-1891376c08050940.js} +1 -1
  26. sky/dashboard/out/_next/static/chunks/{webpack-4abaae354da0ba13.js → webpack-485984ca04e021d0.js} +1 -1
  27. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  28. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  29. sky/dashboard/out/clusters/[cluster].html +1 -1
  30. sky/dashboard/out/clusters.html +1 -1
  31. sky/dashboard/out/config.html +1 -1
  32. sky/dashboard/out/index.html +1 -1
  33. sky/dashboard/out/infra/[context].html +1 -1
  34. sky/dashboard/out/infra.html +1 -1
  35. sky/dashboard/out/jobs/[job].html +1 -1
  36. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  37. sky/dashboard/out/jobs.html +1 -1
  38. sky/dashboard/out/users.html +1 -1
  39. sky/dashboard/out/volumes.html +1 -1
  40. sky/dashboard/out/workspace/new.html +1 -1
  41. sky/dashboard/out/workspaces/[name].html +1 -1
  42. sky/dashboard/out/workspaces.html +1 -1
  43. sky/data/data_utils.py +92 -1
  44. sky/data/mounting_utils.py +39 -0
  45. sky/data/storage.py +166 -9
  46. sky/global_user_state.py +59 -83
  47. sky/jobs/server/server.py +2 -2
  48. sky/jobs/utils.py +5 -6
  49. sky/optimizer.py +1 -1
  50. sky/provision/kubernetes/instance.py +88 -19
  51. sky/provision/kubernetes/volume.py +2 -2
  52. sky/schemas/api/responses.py +2 -5
  53. sky/serve/replica_managers.py +2 -2
  54. sky/serve/serve_utils.py +9 -2
  55. sky/server/requests/payloads.py +2 -0
  56. sky/server/requests/requests.py +182 -84
  57. sky/server/requests/serializers/decoders.py +3 -3
  58. sky/server/requests/serializers/encoders.py +33 -6
  59. sky/server/server.py +34 -7
  60. sky/server/stream_utils.py +56 -13
  61. sky/setup_files/dependencies.py +2 -0
  62. sky/task.py +10 -0
  63. sky/templates/nebius-ray.yml.j2 +1 -0
  64. sky/utils/cli_utils/status_utils.py +8 -2
  65. sky/utils/context_utils.py +13 -1
  66. sky/utils/resources_utils.py +53 -29
  67. {skypilot_nightly-1.0.0.dev20251026.dist-info → skypilot_nightly-1.0.0.dev20251029.dist-info}/METADATA +50 -34
  68. {skypilot_nightly-1.0.0.dev20251026.dist-info → skypilot_nightly-1.0.0.dev20251029.dist-info}/RECORD +74 -73
  69. sky/dashboard/out/_next/static/chunks/2755.227c84f5adf75c6b.js +0 -26
  70. sky/dashboard/out/_next/static/chunks/3015-2dcace420c8939f4.js +0 -1
  71. sky/dashboard/out/_next/static/chunks/3294.6d5054a953a818cb.js +0 -1
  72. sky/dashboard/out/_next/static/chunks/4282-d2f3ef2fbf78e347.js +0 -1
  73. sky/dashboard/out/_next/static/chunks/8969-0389e2cb52412db3.js +0 -1
  74. sky/dashboard/out/_next/static/chunks/9360.07d78b8552bc9d17.js +0 -31
  75. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c815b90e296b8075.js +0 -16
  76. sky/dashboard/out/_next/static/css/4c052b4444e52a58.css +0 -3
  77. /sky/dashboard/out/_next/static/{wDQ7aGvICzMNmjIaC37TT → DabuSAKsc_y0wyJxpTIdQ}/_ssgManifest.js +0 -0
  78. /sky/dashboard/out/_next/static/chunks/pages/{_app-513d332313670f2a.js → _app-bde01e4a2beec258.js} +0 -0
  79. {skypilot_nightly-1.0.0.dev20251026.dist-info → skypilot_nightly-1.0.0.dev20251029.dist-info}/WHEEL +0 -0
  80. {skypilot_nightly-1.0.0.dev20251026.dist-info → skypilot_nightly-1.0.0.dev20251029.dist-info}/entry_points.txt +0 -0
  81. {skypilot_nightly-1.0.0.dev20251026.dist-info → skypilot_nightly-1.0.0.dev20251029.dist-info}/licenses/LICENSE +0 -0
  82. {skypilot_nightly-1.0.0.dev20251026.dist-info → skypilot_nightly-1.0.0.dev20251029.dist-info}/top_level.txt +0 -0
sky/provision/kubernetes/instance.py CHANGED
@@ -33,6 +33,7 @@ from sky.utils.db import db_utils
 POLL_INTERVAL = 2
 _TIMEOUT_FOR_POD_TERMINATION = 60  # 1 minutes
 _MAX_RETRIES = 3
+_MAX_MISSING_PODS_RETRIES = 5
 _NUM_THREADS = subprocess_utils.get_parallel_threads('kubernetes')

 # Pattern to extract SSH user from command output, handling MOTD contamination
@@ -489,17 +490,17 @@ def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int,


 @timeline.event
-def _wait_for_pods_to_run(namespace, context, new_nodes):
+def _wait_for_pods_to_run(namespace, context, cluster_name, new_pods):
     """Wait for pods and their containers to be ready.

     Pods may be pulling images or may be in the process of container
     creation.
     """
-    if not new_nodes:
+    if not new_pods:
         return

     # Create a set of pod names we're waiting for
-    expected_pod_names = {node.metadata.name for node in new_nodes}
+    expected_pod_names = {pod.metadata.name for pod in new_pods}

     def _check_init_containers(pod):
         # Check if any of the init containers failed
@@ -526,28 +527,62 @@ def _wait_for_pods_to_run(namespace, context, new_nodes):
                     'Failed to create init container for pod '
                     f'{pod.metadata.name}. Error details: {msg}.')

+    missing_pods_retry = 0
     while True:
         # Get all pods in a single API call
-        cluster_name = new_nodes[0].metadata.labels[
+        cluster_name_on_cloud = new_pods[0].metadata.labels[
             k8s_constants.TAG_SKYPILOT_CLUSTER_NAME]
         all_pods = kubernetes.core_api(context).list_namespaced_pod(
             namespace,
             label_selector=
-            f'{k8s_constants.TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items
+            f'{k8s_constants.TAG_SKYPILOT_CLUSTER_NAME}={cluster_name_on_cloud}'
+        ).items

         # Get the set of found pod names and check if we have all expected pods
         found_pod_names = {pod.metadata.name for pod in all_pods}
-        missing_pods = expected_pod_names - found_pod_names
-        if missing_pods:
+        missing_pod_names = expected_pod_names - found_pod_names
+        if missing_pod_names:
+            # In _wait_for_pods_to_schedule, we already wait for all pods to go
+            # from pending to scheduled. So if a pod is missing here, it means
+            # something unusual must have happened, and so should be treated as
+            # an exception.
+            # It is also only in _wait_for_pods_to_schedule that
+            # provision_timeout is used.
+            # TODO(kevin): Should we take provision_timeout into account here,
+            # instead of hardcoding the number of retries?
+            if missing_pods_retry >= _MAX_MISSING_PODS_RETRIES:
+                for pod_name in missing_pod_names:
+                    reason = _get_pod_missing_reason(context, namespace,
+                                                     cluster_name, pod_name)
+                    logger.warning(f'Pod {pod_name} missing: {reason}')
+                raise config_lib.KubernetesError(
+                    f'Failed to get all pods after {missing_pods_retry} '
+                    f'retries. Some pods may have been terminated or failed '
+                    f'unexpectedly. Run `sky logs --provision {cluster_name}` '
+                    'for more details.')
             logger.info('Retrying running pods check: '
-                        f'Missing pods: {missing_pods}')
+                        f'Missing pods: {missing_pod_names}')
             time.sleep(0.5)
+            missing_pods_retry += 1
             continue

         all_pods_running = True
         for pod in all_pods:
             if pod.metadata.name not in expected_pod_names:
                 continue
+
+            # Check if pod is terminated/preempted/failed.
+            if (pod.metadata.deletion_timestamp is not None or
+                    pod.status.phase == 'Failed'):
+                # Get the reason and write to cluster events before
+                # the pod gets completely deleted from the API.
+                reason = _get_pod_termination_reason(pod, cluster_name)
+                logger.warning(f'Pod {pod.metadata.name} terminated: {reason}')
+                raise config_lib.KubernetesError(
+                    f'Pod {pod.metadata.name} has terminated or failed '
+                    f'unexpectedly. Run `sky logs --provision {cluster_name}` '
+                    'for more details.')
+
             # Continue if pod and all the containers within the
             # pod are successfully created and running.
             if pod.status.phase == 'Running' and all(
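Note: the missing-pod handling added above is a bounded retry: a pod that disappears after scheduling is tolerated for up to _MAX_MISSING_PODS_RETRIES polls, then provisioning fails with a KubernetesError instead of looping forever. A minimal sketch of the same pattern, with hypothetical poll() and MAX_RETRIES standing in for the real pod-listing call and constant:

    import time

    MAX_RETRIES = 5          # stand-in for _MAX_MISSING_PODS_RETRIES
    POLL_DELAY_SECONDS = 0.5

    def wait_until_all_present(poll, expected: set) -> None:
        """Retry a listing call until every expected name shows up, or fail.

        `poll` is a hypothetical callable returning the set of names currently
        visible; the real code lists pods by the SkyPilot cluster-name label.
        """
        retries = 0
        while True:
            missing = expected - poll()
            if not missing:
                return
            if retries >= MAX_RETRIES:
                raise RuntimeError(
                    f'Still missing after {retries} retries: {sorted(missing)}')
            retries += 1
            time.sleep(POLL_DELAY_SECONDS)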
@@ -1169,7 +1204,7 @@ def _create_pods(region: str, cluster_name: str, cluster_name_on_cloud: str,
     # fail early if there is an error
     logger.debug(f'run_instances: waiting for pods to be running (pulling '
                  f'images): {[pod.metadata.name for pod in pods]}')
-    _wait_for_pods_to_run(namespace, context, pods)
+    _wait_for_pods_to_run(namespace, context, cluster_name, pods)
     logger.debug(f'run_instances: all pods are scheduled and running: '
                  f'{[pod.metadata.name for pod in pods]}')

@@ -1428,9 +1463,45 @@ def get_cluster_info(


 def _get_pod_termination_reason(pod: Any, cluster_name: str) -> str:
-    """Get pod termination reason and write to cluster events."""
-    reasons = []
+    """Get pod termination reason and write to cluster events.
+
+    Checks both pod conditions (for preemption/disruption) and
+    container statuses (for exit codes/errors).
+    """
     latest_timestamp = pod.status.start_time or datetime.datetime.min
+    ready_state = 'Unknown'
+    termination_reason = 'Terminated unexpectedly'
+    container_reasons = []
+
+    # Check pod status conditions for high level overview.
+    # No need to sort, as each condition.type will only appear once.
+    for condition in pod.status.conditions:
+        reason = condition.reason or 'Unknown reason'
+        message = condition.message or ''
+
+        # Get last known readiness state.
+        if condition.type == 'Ready':
+            ready_state = f'{reason} ({message})' if message else reason
+        # Kueue preemption, as defined in:
+        # https://pkg.go.dev/sigs.k8s.io/kueue/pkg/controller/jobs/pod#pkg-constants
+        elif condition.type == 'TerminationTarget':
+            termination_reason = f'Preempted by Kueue: {reason}'
+            if message:
+                termination_reason += f' ({message})'
+        # Generic disruption.
+        elif condition.type == 'DisruptionTarget':
+            termination_reason = f'Disrupted: {reason}'
+            if message:
+                termination_reason += f' ({message})'
+
+        if condition.last_transition_time is not None:
+            latest_timestamp = max(latest_timestamp,
+                                   condition.last_transition_time)
+
+    pod_reason = (f'{termination_reason}.\n'
+                  f'Last known state: {ready_state}.')
+
+    # Check container statuses for exit codes/errors
     if pod.status and pod.status.container_statuses:
         for container_status in pod.status.container_statuses:
             terminated = container_status.state.terminated
@@ -1445,18 +1516,15 @@ def _get_pod_termination_reason(pod: Any, cluster_name: str) -> str:
                 if reason is None:
                     # just in-case reason is None, have default for debugging
                     reason = f'exit({exit_code})'
-                reasons.append(reason)
-                if terminated.finished_at > latest_timestamp:
-                    latest_timestamp = terminated.finished_at
+                container_reasons.append(reason)
+                latest_timestamp = max(latest_timestamp, terminated.finished_at)

     # TODO (kyuds): later, if needed, query `last_state` too.

-    if not reasons:
-        return ''
-
     # Normally we will have a single container per pod for skypilot
     # but doing this just in-case there are multiple containers.
-    pod_reason = ' | '.join(reasons)
+    if container_reasons:
+        pod_reason += f'\nContainer errors: {" | ".join(container_reasons)}'

     global_user_state.add_cluster_event(
         cluster_name,
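Note: the rewritten helper above builds its summary in two passes: pod conditions give the headline (Ready state, Kueue TerminationTarget, DisruptionTarget) and container statuses append exit-code details. A self-contained sketch of just the condition-scanning step, using a plain dataclass in place of the Kubernetes client's condition objects:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Condition:
        # Mirrors the fields the diff reads from each pod condition.
        type: str
        reason: Optional[str] = None
        message: Optional[str] = None

    def summarize(conditions: List[Condition]) -> str:
        """Condense pod conditions into one human-readable reason string."""
        ready_state = 'Unknown'
        termination_reason = 'Terminated unexpectedly'
        for cond in conditions:
            reason = cond.reason or 'Unknown reason'
            message = cond.message or ''
            if cond.type == 'Ready':
                ready_state = f'{reason} ({message})' if message else reason
            elif cond.type == 'TerminationTarget':  # Kueue preemption
                termination_reason = f'Preempted by Kueue: {reason}'
            elif cond.type == 'DisruptionTarget':   # generic disruption
                termination_reason = f'Disrupted: {reason}'
        return f'{termination_reason}. Last known state: {ready_state}.'

    # Example: a pod preempted by Kueue while not yet Ready.
    print(summarize([
        Condition('Ready', 'ContainersNotReady'),
        Condition('TerminationTarget', 'Preempted'),
    ]))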
@@ -1658,9 +1726,10 @@ def query_instances(
                                                  Optional[str]]] = {}
     for pod in pods:
         phase = pod.status.phase
+        is_terminating = pod.metadata.deletion_timestamp is not None
         pod_status = status_map[phase]
         reason = None
-        if phase in ('Failed', 'Unknown'):
+        if phase in ('Failed', 'Unknown') or is_terminating:
             reason = _get_pod_termination_reason(pod, cluster_name)
             logger.debug(f'Pod Status ({phase}) Reason(s): {reason}')
         if non_terminated_only and pod_status is None:
sky/provision/kubernetes/volume.py CHANGED
@@ -75,7 +75,6 @@ def delete_volume(config: models.VolumeConfig) -> models.VolumeConfig:
     """Deletes a volume."""
     context, namespace = _get_context_namespace(config)
     pvc_name = config.name_on_cloud
-    logger.info(f'Deleting PVC {pvc_name}')
     kubernetes_utils.delete_k8s_resource_with_retry(
         delete_func=lambda pvc_name=pvc_name: kubernetes.core_api(
             context).delete_namespaced_persistent_volume_claim(
@@ -84,6 +83,7 @@ def delete_volume(config: models.VolumeConfig) -> models.VolumeConfig:
             _request_timeout=config_lib.DELETION_TIMEOUT),
         resource_type='pvc',
         resource_name=pvc_name)
+    logger.info(f'Deleted PVC {pvc_name} in namespace {namespace}')
     return config


@@ -242,9 +242,9 @@ def create_persistent_volume_claim(namespace: str, context: Optional[str],
     except kubernetes.api_exception() as e:
         if e.status != 404:  # Not found
             raise
-    logger.info(f'Creating PVC {pvc_name}')
    kubernetes.core_api(context).create_namespaced_persistent_volume_claim(
        namespace=namespace, body=pvc_spec)
+    logger.info(f'Created PVC {pvc_name} in namespace {namespace}')


 def _get_pvc_spec(namespace: str,
sky/schemas/api/responses.py CHANGED
@@ -90,7 +90,7 @@ class StatusResponse(ResponseBaseModel):
     # This is an internally facing field anyway, so it's less
     # of a problem that it's not typed.
     handle: Optional[Any] = None
-    last_use: str
+    last_use: Optional[str] = None
     status: status_lib.ClusterStatus
     autostop: int
     to_down: bool
@@ -98,11 +98,8 @@ class StatusResponse(ResponseBaseModel):
     # metadata is a JSON, so we use Any here.
     metadata: Optional[Dict[str, Any]] = None
     cluster_hash: str
-    # pydantic cannot generate the pydantic-core schema for
-    # storage_mounts_metadata, so we use Any here.
-    storage_mounts_metadata: Optional[Dict[str, Any]] = None
     cluster_ever_up: bool
-    status_updated_at: int
+    status_updated_at: Optional[int] = None
     user_hash: str
     user_name: str
     config_hash: Optional[str] = None
sky/serve/replica_managers.py CHANGED
@@ -495,8 +495,8 @@ class ReplicaInfo:
             info_dict['cloud'] = repr(handle.launched_resources.cloud)
             info_dict['region'] = handle.launched_resources.region
             info_dict['resources_str'] = (
-                resources_utils.get_readable_resources_repr(handle,
-                                                            simplify=True))
+                resources_utils.get_readable_resources_repr(
+                    handle, simplified_only=True)[0])
         return info_dict

     def __repr__(self) -> str:
sky/serve/serve_utils.py CHANGED
@@ -1550,8 +1550,15 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], show_all: bool,
             'handle']
         if replica_handle is not None:
             infra = replica_handle.launched_resources.infra.formatted_str()
-            resources_str = resources_utils.get_readable_resources_repr(
-                replica_handle, simplify=not show_all)
+            simplified = not show_all
+            resources_str_simple, resources_str_full = (
+                resources_utils.get_readable_resources_repr(
+                    replica_handle, simplified_only=simplified))
+            if simplified:
+                resources_str = resources_str_simple
+            else:
+                assert resources_str_full is not None
+                resources_str = resources_str_full

         replica_values = [
             service_name,
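Note: both call sites above rely on the new return shape of resources_utils.get_readable_resources_repr; per this diff it returns a (simplified, full) pair, and the full string may be None when simplified_only=True. A small self-contained sketch of consuming that pair, with a stand-in function and made-up resource strings:

    from typing import Optional, Tuple

    def get_repr(simplified_only: bool) -> Tuple[str, Optional[str]]:
        # Stand-in for resources_utils.get_readable_resources_repr(handle, ...);
        # the resource strings below are placeholders, not real output.
        simplified = '1x[CPU:4]'
        full = None if simplified_only else '1x cloud(instance-type, cpus=4)'
        return simplified, full

    show_all = False
    simple_str, full_str = get_repr(simplified_only=not show_all)
    resources_str = simple_str if full_str is None else full_str
    print(resources_str)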
sky/server/requests/payloads.py CHANGED
@@ -319,6 +319,8 @@ class StatusBody(RequestBody):
     # Only return fields that are needed for the
     # dashboard / CLI summary response
     summary_response: bool = False
+    # Include the cluster handle in the response
+    include_handle: bool = True


 class StartBody(RequestBody):
sky/server/requests/requests.py CHANGED
@@ -5,7 +5,6 @@ import contextlib
 import dataclasses
 import enum
 import functools
-import json
 import os
 import pathlib
 import shutil
@@ -21,6 +20,7 @@ import uuid
 import anyio
 import colorama
 import filelock
+import orjson

 from sky import exceptions
 from sky import global_user_state
@@ -213,8 +213,8 @@ class Request:
             entrypoint=self.entrypoint.__name__,
             request_body=self.request_body.model_dump_json(),
             status=self.status.value,
-            return_value=json.dumps(None),
-            error=json.dumps(None),
+            return_value=orjson.dumps(None).decode('utf-8'),
+            error=orjson.dumps(None).decode('utf-8'),
             pid=None,
             created_at=self.created_at,
             schedule_type=self.schedule_type.value,
@@ -237,8 +237,8 @@ class Request:
             entrypoint=encoders.pickle_and_encode(self.entrypoint),
             request_body=encoders.pickle_and_encode(self.request_body),
             status=self.status.value,
-            return_value=json.dumps(self.return_value),
-            error=json.dumps(self.error),
+            return_value=orjson.dumps(self.return_value).decode('utf-8'),
+            error=orjson.dumps(self.error).decode('utf-8'),
             pid=self.pid,
             created_at=self.created_at,
             schedule_type=self.schedule_type.value,
@@ -270,8 +270,8 @@ class Request:
             entrypoint=decoders.decode_and_unpickle(payload.entrypoint),
             request_body=decoders.decode_and_unpickle(payload.request_body),
             status=RequestStatus(payload.status),
-            return_value=json.loads(payload.return_value),
-            error=json.loads(payload.error),
+            return_value=orjson.loads(payload.return_value),
+            error=orjson.loads(payload.error),
             pid=payload.pid,
             created_at=payload.created_at,
             schedule_type=ScheduleType(payload.schedule_type),
@@ -328,10 +328,11 @@ def encode_requests(requests: List[Request]) -> List[payloads.RequestPayload]:
             entrypoint=request.entrypoint.__name__
             if request.entrypoint is not None else '',
             request_body=request.request_body.model_dump_json()
-            if request.request_body is not None else json.dumps(None),
+            if request.request_body is not None else
+            orjson.dumps(None).decode('utf-8'),
             status=request.status.value,
-            return_value=json.dumps(None),
-            error=json.dumps(None),
+            return_value=orjson.dumps(None).decode('utf-8'),
+            error=orjson.dumps(None).decode('utf-8'),
             pid=None,
             created_at=request.created_at,
             schedule_type=request.schedule_type.value,
@@ -372,9 +373,9 @@ def _update_request_row_fields(
     if 'user_id' not in fields:
         content['user_id'] = ''
     if 'return_value' not in fields:
-        content['return_value'] = json.dumps(None)
+        content['return_value'] = orjson.dumps(None).decode('utf-8')
     if 'error' not in fields:
-        content['error'] = json.dumps(None)
+        content['error'] = orjson.dumps(None).decode('utf-8')
     if 'schedule_type' not in fields:
         content['schedule_type'] = ScheduleType.SHORT.value
     # Optional fields in RequestPayload
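Note: the json-to-orjson swap above keeps the stored values as text: orjson.dumps returns bytes, so every call is decoded back to a UTF-8 str and stays drop-in compatible with what json.dumps previously wrote. A quick check of that equivalence (requires the orjson package):

    import json

    import orjson

    # orjson serializes to bytes; decoding gives the same text json produces
    # for a simple payload like None.
    assert orjson.dumps(None) == b'null'
    assert orjson.dumps(None).decode('utf-8') == json.dumps(None) == 'null'

    # Values written with orjson can be read back by either loader.
    payload = {'return_value': None, 'error': None}
    encoded = orjson.dumps(payload).decode('utf-8')
    assert orjson.loads(encoded) == json.loads(encoded) == payload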
@@ -393,76 +394,6 @@ def _update_request_row_fields(
     return tuple(content[col] for col in REQUEST_COLUMNS)


-def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
-    """Kill all pending and running requests for a cluster.
-
-    Args:
-        cluster_name: the name of the cluster.
-        exclude_request_names: exclude requests with these names. This is to
-            prevent killing the caller request.
-    """
-    request_ids = [
-        request_task.request_id
-        for request_task in get_request_tasks(req_filter=RequestTaskFilter(
-            status=[RequestStatus.PENDING, RequestStatus.RUNNING],
-            exclude_request_names=[exclude_request_name],
-            cluster_names=[cluster_name],
-            fields=['request_id']))
-    ]
-    kill_requests(request_ids)
-
-
-def kill_requests(request_ids: Optional[List[str]] = None,
-                  user_id: Optional[str] = None) -> List[str]:
-    """Kill a SkyPilot API request and set its status to cancelled.
-
-    Args:
-        request_ids: The request IDs to kill. If None, all requests for the
-            user are killed.
-        user_id: The user ID to kill requests for. If None, all users are
-            killed.
-
-    Returns:
-        A list of request IDs that were cancelled.
-    """
-    if request_ids is None:
-        request_ids = [
-            request_task.request_id
-            for request_task in get_request_tasks(req_filter=RequestTaskFilter(
-                status=[RequestStatus.PENDING, RequestStatus.RUNNING],
-                # Avoid cancelling the cancel request itself.
-                exclude_request_names=['sky.api_cancel'],
-                user_id=user_id,
-                fields=['request_id']))
-        ]
-    cancelled_request_ids = []
-    for request_id in request_ids:
-        with update_request(request_id) as request_record:
-            if request_record is None:
-                logger.debug(f'No request ID {request_id}')
-                continue
-            # Skip internal requests. The internal requests are scheduled with
-            # request_id in range(len(INTERNAL_REQUEST_EVENTS)).
-            if request_record.request_id in set(
-                    event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
-                continue
-            if request_record.status > RequestStatus.RUNNING:
-                logger.debug(f'Request {request_id} already finished')
-                continue
-            if request_record.pid is not None:
-                logger.debug(f'Killing request process {request_record.pid}')
-                # Use SIGTERM instead of SIGKILL:
-                # - The executor can handle SIGTERM gracefully
-                # - After SIGTERM, the executor can reuse the request process
-                #   for other requests, avoiding the overhead of forking a new
-                #   process for each request.
-                os.kill(request_record.pid, signal.SIGTERM)
-            request_record.status = RequestStatus.CANCELLED
-            request_record.finished_at = time.time()
-            cancelled_request_ids.append(request_id)
-    return cancelled_request_ids
-
-
 def create_table(cursor, conn):
     # Enable WAL mode to avoid locking issues.
     # See: issue #1441 and PR #1509
@@ -606,6 +537,128 @@ def request_lock_path(request_id: str) -> str:
     return os.path.join(lock_path, f'.{request_id}.lock')


+def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
+    """Kill all pending and running requests for a cluster.
+
+    Args:
+        cluster_name: the name of the cluster.
+        exclude_request_names: exclude requests with these names. This is to
+            prevent killing the caller request.
+    """
+    request_ids = [
+        request_task.request_id
+        for request_task in get_request_tasks(req_filter=RequestTaskFilter(
+            status=[RequestStatus.PENDING, RequestStatus.RUNNING],
+            exclude_request_names=[exclude_request_name],
+            cluster_names=[cluster_name],
+            fields=['request_id']))
+    ]
+    _kill_requests(request_ids)
+
+
+def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
+                              user_id: Optional[str] = None) -> List[str]:
+    """Kill requests with a given request ID prefix."""
+    expanded_request_ids: Optional[List[str]] = None
+    if request_ids is not None:
+        expanded_request_ids = []
+        for request_id in request_ids:
+            request_tasks = get_requests_with_prefix(request_id,
+                                                     fields=['request_id'])
+            if request_tasks is None or len(request_tasks) == 0:
+                continue
+            if len(request_tasks) > 1:
+                raise ValueError(f'Multiple requests found for '
+                                 f'request ID prefix: {request_id}')
+            expanded_request_ids.append(request_tasks[0].request_id)
+    return _kill_requests(request_ids=expanded_request_ids, user_id=user_id)
+
+
+def _should_kill_request(request_id: str,
+                         request_record: Optional[Request]) -> bool:
+    if request_record is None:
+        logger.debug(f'No request ID {request_id}')
+        return False
+    # Skip internal requests. The internal requests are scheduled with
+    # request_id in range(len(INTERNAL_REQUEST_EVENTS)).
+    if request_record.request_id in set(
+            event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
+        return False
+    if request_record.status > RequestStatus.RUNNING:
+        logger.debug(f'Request {request_id} already finished')
+        return False
+    return True
+
+
+def _kill_requests(request_ids: Optional[List[str]] = None,
+                   user_id: Optional[str] = None) -> List[str]:
+    """Kill a SkyPilot API request and set its status to cancelled.
+
+    Args:
+        request_ids: The request IDs to kill. If None, all requests for the
+            user are killed.
+        user_id: The user ID to kill requests for. If None, all users are
+            killed.
+
+    Returns:
+        A list of request IDs that were cancelled.
+    """
+    if request_ids is None:
+        request_ids = [
+            request_task.request_id
+            for request_task in get_request_tasks(req_filter=RequestTaskFilter(
+                status=[RequestStatus.PENDING, RequestStatus.RUNNING],
+                # Avoid cancelling the cancel request itself.
+                exclude_request_names=['sky.api_cancel'],
+                user_id=user_id,
+                fields=['request_id']))
+        ]
+    cancelled_request_ids = []
+    for request_id in request_ids:
+        with update_request(request_id) as request_record:
+            if not _should_kill_request(request_id, request_record):
+                continue
+            if request_record.pid is not None:
+                logger.debug(f'Killing request process {request_record.pid}')
+                # Use SIGTERM instead of SIGKILL:
+                # - The executor can handle SIGTERM gracefully
+                # - After SIGTERM, the executor can reuse the request process
+                #   for other requests, avoiding the overhead of forking a new
+                #   process for each request.
+                os.kill(request_record.pid, signal.SIGTERM)
+            request_record.status = RequestStatus.CANCELLED
+            request_record.finished_at = time.time()
+            cancelled_request_ids.append(request_id)
+    return cancelled_request_ids
+
+
+@init_db_async
+@asyncio_utils.shield
+async def kill_request_async(request_id: str) -> bool:
+    """Kill a SkyPilot API request and set its status to cancelled.
+
+    Returns:
+        True if the request was killed, False otherwise.
+    """
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request = await _get_request_no_lock_async(request_id)
+        if not _should_kill_request(request_id, request):
+            return False
+        assert request is not None
+        if request.pid is not None:
+            logger.debug(f'Killing request process {request.pid}')
+            # Use SIGTERM instead of SIGKILL:
+            # - The executor can handle SIGTERM gracefully
+            # - After SIGTERM, the executor can reuse the request process
+            #   for other requests, avoiding the overhead of forking a new
+            #   process for each request.
+            os.kill(request.pid, signal.SIGTERM)
+        request.status = RequestStatus.CANCELLED
+        request.finished_at = time.time()
+        await _add_or_update_request_no_lock_async(request)
+        return True
+
+
 @contextlib.contextmanager
 @init_db
 @metrics_lib.time_me
@@ -620,7 +673,7 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
         _add_or_update_request_no_lock(request)


-@init_db
+@init_db_async
 @metrics_lib.time_me
 @asyncio_utils.shield
 async def update_status_async(request_id: str, status: RequestStatus) -> None:
@@ -632,7 +685,7 @@ async def update_status_async(request_id: str, status: RequestStatus) -> None:
     await _add_or_update_request_no_lock_async(request)


-@init_db
+@init_db_async
 @metrics_lib.time_me
 @asyncio_utils.shield
 async def update_status_msg_async(request_id: str, status_msg: str) -> None:
@@ -715,6 +768,51 @@ async def get_request_async(
     return await _get_request_no_lock_async(request_id, fields)


+@init_db
+@metrics_lib.time_me
+def get_requests_with_prefix(
+        request_id_prefix: str,
+        fields: Optional[List[str]] = None) -> Optional[List[Request]]:
+    """Get requests with a given request ID prefix."""
+    assert _DB is not None
+    if fields:
+        columns_str = ', '.join(fields)
+    else:
+        columns_str = ', '.join(REQUEST_COLUMNS)
+    with _DB.conn:
+        cursor = _DB.conn.cursor()
+        cursor.execute((f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+                        'WHERE request_id LIKE ?'), (request_id_prefix + '%',))
+        rows = cursor.fetchall()
+    if not rows:
+        return None
+    if fields:
+        rows = [_update_request_row_fields(row, fields) for row in rows]
+    return [Request.from_row(row) for row in rows]
+
+
+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def get_requests_async_with_prefix(
+        request_id_prefix: str,
+        fields: Optional[List[str]] = None) -> Optional[List[Request]]:
+    """Async version of get_request_with_prefix."""
+    assert _DB is not None
+    if fields:
+        columns_str = ', '.join(fields)
+    else:
+        columns_str = ', '.join(REQUEST_COLUMNS)
+    async with _DB.execute_fetchall_async(
+            (f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+             'WHERE request_id LIKE ?'), (request_id_prefix + '%',)) as rows:
+        if not rows:
+            return None
+        if fields:
+            rows = [_update_request_row_fields(row, fields) for row in rows]
+        return [Request.from_row(row) for row in rows]
+
+
 class StatusWithMsg(NamedTuple):
     status: RequestStatus
     status_msg: Optional[str] = None
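Note: the prefix lookups added above match request IDs with a parameterized SQL LIKE, building the pattern as request_id_prefix + '%'; callers such as kill_requests_with_prefix treat more than one match as ambiguous. A minimal sketch of the same query shape against an in-memory SQLite table (the table and column layout here is illustrative, not the real schema):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE requests (request_id TEXT PRIMARY KEY)')
    conn.executemany('INSERT INTO requests VALUES (?)',
                     [('abc123',), ('abcf00',), ('def456',)])

    prefix = 'abc'
    rows = conn.execute(
        'SELECT request_id FROM requests WHERE request_id LIKE ?',
        (prefix + '%',)).fetchall()
    print(rows)  # [('abc123',), ('abcf00',)] -> two matches, ambiguous prefix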
sky/server/requests/serializers/decoders.py CHANGED
@@ -56,10 +56,10 @@ def decode_status(
     clusters = return_value
     response = []
     for cluster in clusters:
-        cluster['handle'] = decode_and_unpickle(cluster['handle'])
+        # handle may not always be present in the response.
+        if 'handle' in cluster and cluster['handle'] is not None:
+            cluster['handle'] = decode_and_unpickle(cluster['handle'])
         cluster['status'] = status_lib.ClusterStatus(cluster['status'])
-        cluster['storage_mounts_metadata'] = decode_and_unpickle(
-            cluster['storage_mounts_metadata'])
         if 'is_managed' not in cluster:
             cluster['is_managed'] = False
         response.append(responses.StatusResponse.model_validate(cluster))