skypilot-nightly 1.0.0.dev20251019__py3-none-any.whl → 1.0.0.dev20251021__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of skypilot-nightly might be problematic.

Files changed (49)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/kubernetes.py +61 -0
  3. sky/backends/backend_utils.py +11 -11
  4. sky/backends/cloud_vm_ray_backend.py +15 -4
  5. sky/client/cli/command.py +39 -10
  6. sky/client/cli/flags.py +4 -2
  7. sky/client/sdk.py +26 -3
  8. sky/dashboard/out/404.html +1 -1
  9. sky/dashboard/out/_next/static/chunks/{webpack-3c431f6c9086e487.js → webpack-66f23594d38c7f16.js} +1 -1
  10. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  11. sky/dashboard/out/clusters/[cluster].html +1 -1
  12. sky/dashboard/out/clusters.html +1 -1
  13. sky/dashboard/out/config.html +1 -1
  14. sky/dashboard/out/index.html +1 -1
  15. sky/dashboard/out/infra/[context].html +1 -1
  16. sky/dashboard/out/infra.html +1 -1
  17. sky/dashboard/out/jobs/[job].html +1 -1
  18. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  19. sky/dashboard/out/jobs.html +1 -1
  20. sky/dashboard/out/users.html +1 -1
  21. sky/dashboard/out/volumes.html +1 -1
  22. sky/dashboard/out/workspace/new.html +1 -1
  23. sky/dashboard/out/workspaces/[name].html +1 -1
  24. sky/dashboard/out/workspaces.html +1 -1
  25. sky/data/storage.py +2 -2
  26. sky/global_user_state.py +20 -20
  27. sky/jobs/server/server.py +10 -1
  28. sky/provision/kubernetes/network.py +9 -6
  29. sky/provision/provisioner.py +8 -0
  30. sky/serve/server/server.py +1 -0
  31. sky/server/common.py +9 -2
  32. sky/server/constants.py +1 -1
  33. sky/server/daemons.py +4 -2
  34. sky/server/requests/executor.py +10 -8
  35. sky/server/requests/payloads.py +2 -1
  36. sky/server/requests/preconditions.py +9 -4
  37. sky/server/requests/requests.py +118 -34
  38. sky/server/server.py +57 -24
  39. sky/server/stream_utils.py +127 -38
  40. sky/server/uvicorn.py +18 -17
  41. sky/utils/asyncio_utils.py +63 -3
  42. {skypilot_nightly-1.0.0.dev20251019.dist-info → skypilot_nightly-1.0.0.dev20251021.dist-info}/METADATA +35 -36
  43. {skypilot_nightly-1.0.0.dev20251019.dist-info → skypilot_nightly-1.0.0.dev20251021.dist-info}/RECORD +49 -49
  44. /sky/dashboard/out/_next/static/{8e35zdobdd0bK_Nkba03m → jDc1PlRsl9Cc5FQUMLBu8}/_buildManifest.js +0 -0
  45. /sky/dashboard/out/_next/static/{8e35zdobdd0bK_Nkba03m → jDc1PlRsl9Cc5FQUMLBu8}/_ssgManifest.js +0 -0
  46. {skypilot_nightly-1.0.0.dev20251019.dist-info → skypilot_nightly-1.0.0.dev20251021.dist-info}/WHEEL +0 -0
  47. {skypilot_nightly-1.0.0.dev20251019.dist-info → skypilot_nightly-1.0.0.dev20251021.dist-info}/entry_points.txt +0 -0
  48. {skypilot_nightly-1.0.0.dev20251019.dist-info → skypilot_nightly-1.0.0.dev20251021.dist-info}/licenses/LICENSE +0 -0
  49. {skypilot_nightly-1.0.0.dev20251019.dist-info → skypilot_nightly-1.0.0.dev20251021.dist-info}/top_level.txt +0 -0
sky/server/requests/requests.py CHANGED
@@ -400,7 +400,8 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
         for request_task in get_request_tasks(req_filter=RequestTaskFilter(
             status=[RequestStatus.PENDING, RequestStatus.RUNNING],
             exclude_request_names=[exclude_request_name],
-            cluster_names=[cluster_name]))
+            cluster_names=[cluster_name],
+            fields=['request_id']))
     ]
     kill_requests(request_ids)
 
@@ -425,7 +426,8 @@ def kill_requests(request_ids: Optional[List[str]] = None,
             status=[RequestStatus.PENDING, RequestStatus.RUNNING],
             # Avoid cancelling the cancel request itself.
             exclude_request_names=['sky.api_cancel'],
-            user_id=user_id))
+            user_id=user_id,
+            fields=['request_id']))
     ]
     cancelled_request_ids = []
     for request_id in request_ids:
@@ -592,6 +594,18 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
         _add_or_update_request_no_lock(request)
 
 
+@init_db
+@metrics_lib.time_me
+@asyncio_utils.shield
+async def update_status_async(request_id: str, status: RequestStatus) -> None:
+    """Update the status of a request"""
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request = await _get_request_no_lock_async(request_id)
+        if request is not None:
+            request.status = status
+            await _add_or_update_request_no_lock_async(request)
+
+
 @init_db
 @metrics_lib.time_me
 @asyncio_utils.shield
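The new `update_status_async` applies the same read-modify-write pattern as its synchronous counterpart, but acquires `filelock.AsyncFileLock` so the event loop is not blocked while waiting on the per-request lock. A minimal standalone sketch of that pattern, with an in-memory store and a hypothetical `Record` type standing in for the requests table (the real code goes through `_get_request_no_lock_async`):

```python
import asyncio
import dataclasses

import filelock  # AsyncFileLock is available in recent filelock releases


@dataclasses.dataclass
class Record:
    request_id: str
    status: str


_STORE = {'req-1': Record('req-1', 'PENDING')}  # stand-in for the DB


async def update_status(request_id: str, status: str) -> None:
    # Read-modify-write under an async file lock, mirroring the hunk above;
    # the lock path is illustrative.
    async with filelock.AsyncFileLock(f'/tmp/{request_id}.lock'):
        record = _STORE.get(request_id)
        if record is not None:
            record.status = status  # persisted via an UPDATE in the real code


asyncio.run(update_status('req-1', 'RUNNING'))
```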
@@ -604,30 +618,42 @@ async def update_status_msg_async(request_id: str, status_msg: str) -> None:
         await _add_or_update_request_no_lock_async(request)
 
 
-_get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} '
-                    'WHERE request_id LIKE ?')
-
-
-def _get_request_no_lock(request_id: str) -> Optional[Request]:
+def _get_request_no_lock(
+        request_id: str,
+        fields: Optional[List[str]] = None) -> Optional[Request]:
     """Get a SkyPilot API request."""
     assert _DB is not None
+    columns_str = ', '.join(REQUEST_COLUMNS)
+    if fields:
+        columns_str = ', '.join(fields)
     with _DB.conn:
         cursor = _DB.conn.cursor()
-        cursor.execute(_get_request_sql, (request_id + '%',))
+        cursor.execute((f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+                        'WHERE request_id LIKE ?'), (request_id + '%',))
        row = cursor.fetchone()
        if row is None:
            return None
+        if fields:
+            row = _update_request_row_fields(row, fields)
        return Request.from_row(row)
 
 
-async def _get_request_no_lock_async(request_id: str) -> Optional[Request]:
+async def _get_request_no_lock_async(
+        request_id: str,
+        fields: Optional[List[str]] = None) -> Optional[Request]:
     """Async version of _get_request_no_lock."""
     assert _DB is not None
-    async with _DB.execute_fetchall_async(_get_request_sql,
-                                          (request_id + '%',)) as rows:
+    columns_str = ', '.join(REQUEST_COLUMNS)
+    if fields:
+        columns_str = ', '.join(fields)
+    async with _DB.execute_fetchall_async(
+            (f'SELECT {columns_str} FROM {REQUEST_TABLE} '
+             'WHERE request_id LIKE ?'), (request_id + '%',)) as rows:
         row = rows[0] if rows else None
         if row is None:
             return None
+        if fields:
+            row = _update_request_row_fields(row, fields)
         return Request.from_row(row)
 
 
@@ -646,20 +672,23 @@ def get_latest_request_id() -> Optional[str]:
 
 @init_db
 @metrics_lib.time_me
-def get_request(request_id: str) -> Optional[Request]:
+def get_request(request_id: str,
+                fields: Optional[List[str]] = None) -> Optional[Request]:
     """Get a SkyPilot API request."""
     with filelock.FileLock(request_lock_path(request_id)):
-        return _get_request_no_lock(request_id)
+        return _get_request_no_lock(request_id, fields)
 
 
 @init_db_async
 @metrics_lib.time_me_async
 @asyncio_utils.shield
-async def get_request_async(request_id: str) -> Optional[Request]:
+async def get_request_async(
+        request_id: str,
+        fields: Optional[List[str]] = None) -> Optional[Request]:
     """Async version of get_request."""
     # TODO(aylei): figure out how to remove FileLock here to avoid the overhead
     async with filelock.AsyncFileLock(request_lock_path(request_id)):
-        return await _get_request_no_lock_async(request_id)
+        return await _get_request_no_lock_async(request_id, fields)
 
 
 class StatusWithMsg(NamedTuple):
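Several hunks in this file thread an optional `fields` list through the getters so callers can SELECT only the columns they need. The narrowed row is then widened back by `_update_request_row_fields`, which this diff references but does not show. A plausible sketch, assuming `REQUEST_COLUMNS` is the ordered column list that `Request.from_row` unpacks (the column names below are illustrative, not the real schema):

```python
from typing import Any, List, Optional, Tuple

# Illustrative only; the real REQUEST_COLUMNS lives in
# sky/server/requests/requests.py and may differ.
REQUEST_COLUMNS = ['request_id', 'name', 'entrypoint', 'request_body',
                   'status', 'created_at', 'user_id', 'status_msg',
                   'should_retry', 'finished_at', 'schedule_type']


def _update_request_row_fields(row: Tuple[Any, ...],
                               fields: List[str]) -> Tuple[Optional[Any], ...]:
    """Widen a partial row (only `fields` selected) to full column order.

    Request.from_row presumably consumes one value per REQUEST_COLUMNS
    entry, so columns that were not selected are padded with None.
    """
    selected = dict(zip(fields, row))
    return tuple(selected.get(column) for column in REQUEST_COLUMNS)


# e.g. a row fetched with fields=['request_id', 'schedule_type']:
print(_update_request_row_fields(('abc123', 'LONG'),
                                 ['request_id', 'schedule_type']))
```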
@@ -896,6 +925,23 @@ def set_request_failed(request_id: str, e: BaseException) -> None:
         request_task.set_error(e)
 
 
+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def set_request_failed_async(request_id: str, e: BaseException) -> None:
+    """Set a request to failed and populate the error message."""
+    with ux_utils.enable_traceback():
+        stacktrace = traceback.format_exc()
+        setattr(e, 'stacktrace', stacktrace)
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request_task = await _get_request_no_lock_async(request_id)
+        assert request_task is not None, request_id
+        request_task.status = RequestStatus.FAILED
+        request_task.finished_at = time.time()
+        request_task.set_error(e)
+        await _add_or_update_request_no_lock_async(request_task)
+
+
 def set_request_succeeded(request_id: str, result: Optional[Any]) -> None:
     """Set a request to succeeded and populate the result."""
     with update_request(request_id) as request_task:
@@ -906,28 +952,50 @@ def set_request_succeeded(request_id: str, result: Optional[Any]) -> None:
         request_task.set_return_value(result)
 
 
-def set_request_cancelled(request_id: str) -> None:
+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def set_request_succeeded_async(request_id: str,
+                                      result: Optional[Any]) -> None:
+    """Set a request to succeeded and populate the result."""
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request_task = await _get_request_no_lock_async(request_id)
+        assert request_task is not None, request_id
+        request_task.status = RequestStatus.SUCCEEDED
+        request_task.finished_at = time.time()
+        if result is not None:
+            request_task.set_return_value(result)
+        await _add_or_update_request_no_lock_async(request_task)
+
+
+@init_db_async
+@metrics_lib.time_me_async
+@asyncio_utils.shield
+async def set_request_cancelled_async(request_id: str) -> None:
     """Set a pending or running request to cancelled."""
-    with update_request(request_id) as request_task:
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request_task = await _get_request_no_lock_async(request_id)
         assert request_task is not None, request_id
         # Already finished or cancelled.
         if request_task.status > RequestStatus.RUNNING:
             return
         request_task.finished_at = time.time()
         request_task.status = RequestStatus.CANCELLED
+        await _add_or_update_request_no_lock_async(request_task)
 
 
 @init_db
 @metrics_lib.time_me
-async def _delete_requests(requests: List[Request]):
+async def _delete_requests(request_ids: List[str]):
     """Clean up requests by their IDs."""
-    id_list_str = ','.join(repr(req.request_id) for req in requests)
+    id_list_str = ','.join(repr(request_id) for request_id in request_ids)
     assert _DB is not None
     await _DB.execute_and_commit_async(
         f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')
 
 
-async def clean_finished_requests_with_retention(retention_seconds: int):
+async def clean_finished_requests_with_retention(retention_seconds: int,
+                                                 batch_size: int = 1000):
     """Clean up finished requests older than the retention period.
 
     This function removes old finished requests (SUCCEEDED, FAILED, CANCELLED)
@@ -936,24 +1004,40 @@ async def clean_finished_requests_with_retention(retention_seconds: int,
     Args:
         retention_seconds: Requests older than this many seconds will be
             deleted.
+        batch_size: batch delete 'batch_size' requests at a time to
+            avoid using too much memory at once and to let each
+            db query complete in a reasonable time. All stale
+            requests older than the retention period will be deleted
+            regardless of the batch size.
     """
-    reqs = await get_request_tasks_async(
-        req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
-                                     finished_before=time.time() -
-                                     retention_seconds))
-
-    futs = []
-    for req in reqs:
-        futs.append(
-            asyncio.create_task(
-                anyio.Path(req.log_path.absolute()).unlink(missing_ok=True)))
-    await asyncio.gather(*futs)
-
-    await _delete_requests(reqs)
+    total_deleted = 0
+    while True:
+        reqs = await get_request_tasks_async(
+            req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
+                                         finished_before=time.time() -
+                                         retention_seconds,
+                                         limit=batch_size,
+                                         fields=['request_id']))
+        if len(reqs) == 0:
+            break
+        futs = []
+        for req in reqs:
+            # req.log_path is derived from request_id,
+            # so it's ok to just grab the request_id in the above query.
+            futs.append(
+                asyncio.create_task(
+                    anyio.Path(
+                        req.log_path.absolute()).unlink(missing_ok=True)))
+        await asyncio.gather(*futs)
+
+        await _delete_requests([req.request_id for req in reqs])
+        total_deleted += len(reqs)
+        if len(reqs) < batch_size:
+            break
 
     # To avoid leakage of the log file, logs must be deleted before the
     # request task in the database.
-    logger.info(f'Cleaned up {len(reqs)} finished requests '
+    logger.info(f'Cleaned up {total_deleted} finished requests '
                 f'older than {retention_seconds} seconds')
 
 
sky/server/server.py CHANGED
@@ -43,6 +43,7 @@ from sky.data import storage_utils
 from sky.jobs import utils as managed_job_utils
 from sky.jobs.server import server as jobs_rest
 from sky.metrics import utils as metrics_utils
+from sky.provision import metadata_utils
 from sky.provision.kubernetes import utils as kubernetes_utils
 from sky.schemas.api import responses
 from sky.serve.server import server as serve_rest
@@ -1270,6 +1271,7 @@ async def logs(
         request_id=request.state.request_id,
         logs_path=request_task.log_path,
         background_tasks=background_tasks,
+        kill_request_on_disconnect=False,
     )
 
 
@@ -1363,38 +1365,65 @@ async def download(download_body: payloads.DownloadBody,
 
 # TODO(aylei): run it asynchronously after global_user_state support async op
 @app.post('/provision_logs')
-def provision_logs(cluster_body: payloads.ClusterNameBody,
+def provision_logs(provision_logs_body: payloads.ProvisionLogsBody,
                    follow: bool = True,
                    tail: int = 0) -> fastapi.responses.StreamingResponse:
     """Streams the provision.log for the latest launch request of a cluster."""
-    # Prefer clusters table first, then cluster_history as fallback.
-    log_path_str = global_user_state.get_cluster_provision_log_path(
-        cluster_body.cluster_name)
-    if not log_path_str:
-        log_path_str = global_user_state.get_cluster_history_provision_log_path(
-            cluster_body.cluster_name)
-    if not log_path_str:
-        raise fastapi.HTTPException(
-            status_code=404,
-            detail=('Provision log path is not recorded for this cluster. '
-                    'Please relaunch to generate provisioning logs.'))
+    log_path = None
+    cluster_name = provision_logs_body.cluster_name
+    worker = provision_logs_body.worker
+    # stream head node logs
+    if worker is None:
+        # Prefer clusters table first, then cluster_history as fallback.
+        log_path_str = global_user_state.get_cluster_provision_log_path(
+            cluster_name)
+        if not log_path_str:
+            log_path_str = (
+                global_user_state.get_cluster_history_provision_log_path(
+                    cluster_name))
+        if not log_path_str:
+            raise fastapi.HTTPException(
+                status_code=404,
+                detail=('Provision log path is not recorded for this cluster. '
+                        'Please relaunch to generate provisioning logs.'))
+        log_path = pathlib.Path(log_path_str).expanduser().resolve()
+        if not log_path.exists():
+            raise fastapi.HTTPException(
+                status_code=404,
+                detail=f'Provision log path does not exist: {str(log_path)}')
 
-    log_path = pathlib.Path(log_path_str).expanduser().resolve()
-    if not log_path.exists():
-        raise fastapi.HTTPException(
-            status_code=404,
-            detail=f'Provision log path does not exist: {str(log_path)}')
+    # stream worker node logs
+    else:
+        handle = global_user_state.get_handle_from_cluster_name(cluster_name)
+        if handle is None:
+            raise fastapi.HTTPException(
+                status_code=404,
+                detail=('Cluster handle is not recorded for this cluster. '
+                        'Please relaunch to generate provisioning logs.'))
+        # instance_ids includes head node
+        instance_ids = handle.instance_ids
+        if instance_ids is None:
+            raise fastapi.HTTPException(
+                status_code=400,
+                detail='Instance IDs are not recorded for this cluster. '
+                'Please relaunch to generate provisioning logs.')
+        if worker > len(instance_ids) - 1:
+            raise fastapi.HTTPException(
+                status_code=400,
+                detail=f'Worker {worker} is out of range. '
+                f'The cluster has {len(instance_ids)} nodes.')
+        log_path = metadata_utils.get_instance_log_dir(
+            handle.get_cluster_name_on_cloud(), instance_ids[worker])
 
     # Tail semantics: 0 means print all lines. Convert 0 -> None for streamer.
     effective_tail = None if tail is None or tail <= 0 else tail
 
     return fastapi.responses.StreamingResponse(
-        content=stream_utils.log_streamer(
-            None,
-            log_path,
-            tail=effective_tail,
-            follow=follow,
-            cluster_name=cluster_body.cluster_name),
+        content=stream_utils.log_streamer(None,
+                                          log_path,
+                                          tail=effective_tail,
+                                          follow=follow,
+                                          cluster_name=cluster_name),
         media_type='text/plain',
         headers={
             'Cache-Control': 'no-cache, no-transform',
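On the client side, the reworked endpoint can be exercised with a plain streaming POST. A hedged sketch: the JSON fields mirror how the handler reads `payloads.ProvisionLogsBody` (`worker` omitted or null streams the head node's provision log; otherwise it indexes into the cluster's `instance_ids`, which includes the head node), while the address and port are assumptions for a locally running API server:

```python
import requests  # third-party HTTP client; any streaming client works

# Hypothetical local API server address and example cluster name.
resp = requests.post('http://127.0.0.1:46580/provision_logs',
                     params={'follow': 'false', 'tail': 100},
                     json={'cluster_name': 'my-cluster', 'worker': 1},
                     stream=True)
resp.raise_for_status()
# The server streams text/plain chunks; print them as they arrive.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end='')
```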
@@ -1567,11 +1596,14 @@ async def stream(
     polling_interval = stream_utils.DEFAULT_POLL_INTERVAL
     # Original plain text streaming logic
     if request_id is not None:
-        request_task = await requests_lib.get_request_async(request_id)
+        request_task = await requests_lib.get_request_async(
+            request_id, fields=['request_id', 'schedule_type'])
         if request_task is None:
             print(f'No task with request ID {request_id}')
             raise fastapi.HTTPException(
                 status_code=404, detail=f'Request {request_id!r} not found')
+        # req.log_path is derived from request_id,
+        # so it's ok to just grab the request_id in the above query.
         log_path_to_stream = request_task.log_path
         if not log_path_to_stream.exists():
             # The log file might be deleted by the request GC daemon but the
@@ -1581,6 +1613,7 @@
                 detail=f'Log of request {request_id!r} has been deleted')
         if request_task.schedule_type == requests_lib.ScheduleType.LONG:
             polling_interval = stream_utils.LONG_REQUEST_POLL_INTERVAL
+        del request_task
     else:
         assert log_path is not None, (request_id, log_path)
         if log_path == constants.API_SERVER_LOGS:
sky/server/stream_utils.py CHANGED
@@ -25,6 +25,17 @@ logger = sky_logging.init_logger(__name__)
 _BUFFER_SIZE = 8 * 1024  # 8KB
 _BUFFER_TIMEOUT = 0.02  # 20ms
 _HEARTBEAT_INTERVAL = 30
+# If a SHORT request has been stuck in pending for
+# _SHORT_REQUEST_SPINNER_TIMEOUT seconds, we show the waiting spinner.
+_SHORT_REQUEST_SPINNER_TIMEOUT = 2
+# If there is an issue during provisioning that causes the cluster to be stuck
+# in INIT state, we use this timeout to break the loop and stop streaming
+# provision logs.
+_PROVISION_LOG_TIMEOUT = 3
+# Maximum time to wait for new log files to appear when streaming worker node
+# provision logs. Worker logs are created sequentially during the provisioning
+# process, so we need to wait for new files to appear.
+_MAX_WAIT_FOR_NEW_LOG_FILES = 3  # seconds
 
 LONG_REQUEST_POLL_INTERVAL = 1
 DEFAULT_POLL_INTERVAL = 0.1
@@ -45,7 +56,7 @@ async def _yield_log_file_with_payloads_skipped(
 
 async def log_streamer(
     request_id: Optional[str],
-    log_path: pathlib.Path,
+    log_path: Optional[pathlib.Path] = None,
     plain_logs: bool = False,
     tail: Optional[int] = None,
     follow: bool = True,
@@ -57,7 +68,9 @@
     Args:
         request_id: The request ID to check whether the log tailing process
             should be stopped.
-        log_path: The path to the log file.
+        log_path: The path to the log file or directory containing the log
+            files. If it is a directory, all *.log files in the directory
+            will be streamed.
         plain_logs: Whether to show plain logs.
         tail: The number of lines to tail. If None, tail the whole file.
         follow: Whether to follow the log file.
@@ -66,17 +79,26 @@
     """
 
     if request_id is not None:
+        start_time = asyncio.get_event_loop().time()
         status_msg = rich_utils.EncodedStatusMessage(
             f'[dim]Checking request: {request_id}[/dim]')
-        request_task = await requests_lib.get_request_async(request_id)
+        request_task = await requests_lib.get_request_async(request_id,
+                                                            fields=[
+                                                                'request_id',
+                                                                'name',
+                                                                'schedule_type',
+                                                                'status',
+                                                                'status_msg'
+                                                            ])
 
         if request_task is None:
             raise fastapi.HTTPException(
                 status_code=404, detail=f'Request {request_id} not found')
         request_id = request_task.request_id
 
-        # Do not show the waiting spinner if the request is a fast, non-blocking
-        # request.
+        # By default, do not show the waiting spinner for SHORT requests.
+        # If the request has been stuck in pending for
+        # _SHORT_REQUEST_SPINNER_TIMEOUT seconds, we show the waiting spinner.
         show_request_waiting_spinner = (not plain_logs and
                                         request_task.schedule_type
                                         == requests_lib.ScheduleType.LONG)
@@ -89,14 +111,23 @@
                          f'scheduled: {request_id}')
         req_status = request_task.status
         req_msg = request_task.status_msg
+        del request_task
         # Slowly back off the database polling up to every 1 second, to avoid
         # overloading the CPU and DB.
         backoff = common_utils.Backoff(initial_backoff=polling_interval,
                                        max_backoff_factor=10,
                                        multiplier=1.2)
         while req_status < requests_lib.RequestStatus.RUNNING:
+            current_time = asyncio.get_event_loop().time()
+            # Show the waiting spinner for a SHORT request if it has been stuck
+            # in pending for _SHORT_REQUEST_SPINNER_TIMEOUT seconds.
+            if not show_request_waiting_spinner and (
+                    current_time - start_time > _SHORT_REQUEST_SPINNER_TIMEOUT):
+                show_request_waiting_spinner = True
+                yield status_msg.init()
+                yield status_msg.start()
             if req_msg is not None:
-                waiting_msg = request_task.status_msg
+                waiting_msg = req_msg
             if show_request_waiting_spinner:
                 yield status_msg.update(f'[dim]{waiting_msg}[/dim]')
             elif plain_logs and waiting_msg != last_waiting_msg:
@@ -119,11 +150,57 @@
         if show_request_waiting_spinner:
             yield status_msg.stop()
 
-    async with aiofiles.open(log_path, 'rb') as f:
-        async for chunk in _tail_log_file(f, request_id, plain_logs, tail,
-                                          follow, cluster_name,
-                                          polling_interval):
-            yield chunk
+    if log_path is not None and log_path.is_dir():
+        # Track which log files we've already streamed
+        streamed_files = set()
+        no_new_files_count = 0
+
+        while True:
+            # Get all *.log files in the log_path
+            log_files = sorted(log_path.glob('*.log'))
+
+            # Filter out already streamed files
+            new_files = [f for f in log_files if f not in streamed_files]
+
+            if len(new_files) == 0:
+                if not follow:
+                    break
+                # Wait a bit to see if new files appear
+                await asyncio.sleep(0.5)
+                no_new_files_count += 1
+                # Check if we've waited too long for new files
+                if no_new_files_count > _MAX_WAIT_FOR_NEW_LOG_FILES * 2:
+                    break
+                continue
+
+            # Reset the no-new-files counter when we find new files
+            no_new_files_count = 0
+
+            for log_file_path in new_files:
+                # Add header before each file (similar to tail -f behavior)
+                header = f'\n==> {log_file_path} <==\n\n'
+                yield header
+
+                async with aiofiles.open(log_file_path, 'rb') as f:
+                    async for chunk in _tail_log_file(f, request_id, plain_logs,
+                                                      tail, follow,
+                                                      cluster_name,
+                                                      polling_interval):
+                        yield chunk
+
+                # Mark this file as streamed
+                streamed_files.add(log_file_path)
+
+            # If not following, break after streaming all current files
+            if not follow:
+                break
+    else:
+        assert log_path is not None, (request_id, log_path)
+        async with aiofiles.open(log_path, 'rb') as f:
+            async for chunk in _tail_log_file(f, request_id, plain_logs, tail,
+                                              follow, cluster_name,
+                                              polling_interval):
+                yield chunk
 
 
 async def _tail_log_file(
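Since `log_path` may now be a directory, the streamer emits each `*.log` file in sorted order with a `tail(1)`-style `==> file <==` header, briefly polling for late-arriving worker logs before giving up. A usage sketch, assuming chunks are yielded as text and using an illustrative directory path:

```python
import asyncio
import pathlib

from sky.server import stream_utils


async def dump_dir_logs(log_dir: pathlib.Path) -> None:
    # request_id=None skips the request-status checks; follow=False streams
    # every *.log currently in the directory, then stops.
    async for chunk in stream_utils.log_streamer(request_id=None,
                                                 log_path=log_dir,
                                                 follow=False):
        print(chunk, end='')


asyncio.run(dump_dir_logs(pathlib.Path('/var/tmp/provision-logs')))
```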
@@ -197,7 +274,7 @@
                     if (req_status.status ==
                             requests_lib.RequestStatus.CANCELLED):
                         request_task = await requests_lib.get_request_async(
-                            request_id)
+                            request_id, fields=['name', 'should_retry'])
                         if request_task.should_retry:
                             buffer.append(
                                 message_utils.encode_payload(
@@ -206,6 +283,7 @@
                         buffer.append(
                             f'{request_task.name!r} request {request_id}'
                             ' cancelled\n')
+                        del request_task
                         break
                 if not follow:
                     # The below checks (cluster status, heartbeat) are not needed
@@ -213,21 +291,24 @@
                     break
             # Provision logs pass in cluster_name, check cluster status
             # periodically to see if provisioning is done.
-            if cluster_name is not None and should_check_status:
-                last_status_check_time = current_time
-                cluster_status = await (
-                    global_user_state.get_status_from_cluster_name_async(
-                        cluster_name))
-                if cluster_status is None:
-                    logger.debug(
-                        'Stop tailing provision logs for cluster'
-                        f' status for cluster {cluster_name} not found')
-                    break
-                if cluster_status != status_lib.ClusterStatus.INIT:
-                    logger.debug(f'Stop tailing provision logs for cluster'
-                                 f' {cluster_name} has status {cluster_status} '
-                                 '(not in INIT state)')
+            if cluster_name is not None:
+                if current_time - last_flush_time > _PROVISION_LOG_TIMEOUT:
                     break
+                if should_check_status:
+                    last_status_check_time = current_time
+                    cluster_status = await (
+                        global_user_state.get_status_from_cluster_name_async(
+                            cluster_name))
+                    if cluster_status is None:
+                        logger.debug(
+                            'Stop tailing provision logs for cluster'
+                            f' status for cluster {cluster_name} not found')
+                        break
+                    if cluster_status != status_lib.ClusterStatus.INIT:
+                        logger.debug(
+                            f'Stop tailing provision logs for cluster'
+                            f' {cluster_name} has status {cluster_status} '
+                            '(not in INIT state)')
             if current_time - last_heartbeat_time >= _HEARTBEAT_INTERVAL:
                 # Currently just used to keep the connection busy, refer to
                 # https://github.com/skypilot-org/skypilot/issues/5750 for
@@ -267,28 +348,36 @@ def stream_response_for_long_request(
     request_id: str,
     logs_path: pathlib.Path,
     background_tasks: fastapi.BackgroundTasks,
+    kill_request_on_disconnect: bool = True,
 ) -> fastapi.responses.StreamingResponse:
-    return stream_response(request_id,
-                           logs_path,
-                           background_tasks,
-                           polling_interval=LONG_REQUEST_POLL_INTERVAL)
+    """Stream the logs of a long request."""
+    return stream_response(
+        request_id,
+        logs_path,
+        background_tasks,
+        polling_interval=LONG_REQUEST_POLL_INTERVAL,
+        kill_request_on_disconnect=kill_request_on_disconnect,
+    )
 
 
 def stream_response(
     request_id: str,
     logs_path: pathlib.Path,
     background_tasks: fastapi.BackgroundTasks,
-    polling_interval: float = DEFAULT_POLL_INTERVAL
+    polling_interval: float = DEFAULT_POLL_INTERVAL,
+    kill_request_on_disconnect: bool = True,
 ) -> fastapi.responses.StreamingResponse:
 
-    async def on_disconnect():
-        logger.info(f'User terminated the connection for request '
-                    f'{request_id}')
-        requests_lib.kill_requests([request_id])
+    if kill_request_on_disconnect:
+
+        async def on_disconnect():
+            logger.info(f'User terminated the connection for request '
+                        f'{request_id}')
+            requests_lib.kill_requests([request_id])
 
-    # The background task will be run after returning a response.
-    # https://fastapi.tiangolo.com/tutorial/background-tasks/
-    background_tasks.add_task(on_disconnect)
+        # The background task will be run after returning a response.
+        # https://fastapi.tiangolo.com/tutorial/background-tasks/
+        background_tasks.add_task(on_disconnect)
 
     return fastapi.responses.StreamingResponse(
         log_streamer(request_id, logs_path, polling_interval=polling_interval),
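`kill_request_on_disconnect` makes the kill-on-disconnect behavior opt-out; the `logs` handler in sky/server/server.py above now passes `False`, so closing a log viewer no longer cancels the underlying request. The mechanism relies on FastAPI background tasks, which run once a streaming response finishes, including when the client disconnects early. A minimal standalone sketch with hypothetical names:

```python
import asyncio

import fastapi

app = fastapi.FastAPI()


def cancel_request(request_id: str) -> None:
    # Stand-in for requests_lib.kill_requests([request_id]).
    print(f'Connection closed; cancelling request {request_id}')


@app.get('/stream/{request_id}')
async def stream(request_id: str,
                 background_tasks: fastapi.BackgroundTasks,
                 kill_request_on_disconnect: bool = True):

    async def body():
        for i in range(5):
            yield f'line {i}\n'
            await asyncio.sleep(0.1)

    if kill_request_on_disconnect:
        # Runs after the response completes or the client disconnects;
        # FastAPI attaches the injected BackgroundTasks to the returned
        # response.
        background_tasks.add_task(cancel_request, request_id)
    return fastapi.responses.StreamingResponse(body(), media_type='text/plain')
```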