skypilot-nightly 1.0.0.dev20250901__py3-none-any.whl → 1.0.0.dev20250903__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.
Files changed (63)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/runpod.py +68 -0
  3. sky/backends/backend_utils.py +5 -3
  4. sky/client/cli/command.py +20 -5
  5. sky/clouds/kubernetes.py +1 -1
  6. sky/clouds/runpod.py +17 -0
  7. sky/dashboard/out/404.html +1 -1
  8. sky/dashboard/out/_next/static/chunks/1121-ec35954c8cbea535.js +1 -0
  9. sky/dashboard/out/_next/static/chunks/3015-8089ed1e0b7e37fd.js +1 -0
  10. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-b77360a343d48902.js +16 -0
  11. sky/dashboard/out/_next/static/chunks/webpack-60556df644cd5d71.js +1 -0
  12. sky/dashboard/out/_next/static/{EqPZ0ygxa__3XPBVJ9dpy → yLz6EPhW_XXmnNs1I6dmS}/_buildManifest.js +1 -1
  13. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  14. sky/dashboard/out/clusters/[cluster].html +1 -1
  15. sky/dashboard/out/clusters.html +1 -1
  16. sky/dashboard/out/config.html +1 -1
  17. sky/dashboard/out/index.html +1 -1
  18. sky/dashboard/out/infra/[context].html +1 -1
  19. sky/dashboard/out/infra.html +1 -1
  20. sky/dashboard/out/jobs/[job].html +1 -1
  21. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  22. sky/dashboard/out/jobs.html +1 -1
  23. sky/dashboard/out/users.html +1 -1
  24. sky/dashboard/out/volumes.html +1 -1
  25. sky/dashboard/out/workspace/new.html +1 -1
  26. sky/dashboard/out/workspaces/[name].html +1 -1
  27. sky/dashboard/out/workspaces.html +1 -1
  28. sky/global_user_state.py +5 -2
  29. sky/models.py +1 -0
  30. sky/provision/runpod/__init__.py +3 -0
  31. sky/provision/runpod/instance.py +17 -0
  32. sky/provision/runpod/utils.py +23 -5
  33. sky/provision/runpod/volume.py +158 -0
  34. sky/server/auth/oauth2_proxy.py +6 -0
  35. sky/server/requests/payloads.py +7 -1
  36. sky/server/requests/preconditions.py +8 -7
  37. sky/server/requests/requests.py +123 -57
  38. sky/server/server.py +32 -25
  39. sky/server/stream_utils.py +14 -6
  40. sky/server/uvicorn.py +2 -1
  41. sky/templates/kubernetes-ray.yml.j2 +5 -5
  42. sky/templates/runpod-ray.yml.j2 +8 -0
  43. sky/utils/benchmark_utils.py +60 -0
  44. sky/utils/command_runner.py +4 -0
  45. sky/utils/db/migration_utils.py +20 -4
  46. sky/utils/resource_checker.py +6 -5
  47. sky/utils/schemas.py +1 -1
  48. sky/utils/volume.py +3 -0
  49. sky/volumes/client/sdk.py +28 -0
  50. sky/volumes/server/server.py +11 -1
  51. sky/volumes/utils.py +117 -68
  52. sky/volumes/volume.py +98 -39
  53. {skypilot_nightly-1.0.0.dev20250901.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/METADATA +33 -33
  54. {skypilot_nightly-1.0.0.dev20250901.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/RECORD +59 -57
  55. sky/dashboard/out/_next/static/chunks/1121-8afcf719ea87debc.js +0 -1
  56. sky/dashboard/out/_next/static/chunks/3015-6c9c09593b1e67b6.js +0 -1
  57. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-06afb50d25f7c61f.js +0 -16
  58. sky/dashboard/out/_next/static/chunks/webpack-6e76f636a048e145.js +0 -1
  59. /sky/dashboard/out/_next/static/{EqPZ0ygxa__3XPBVJ9dpy → yLz6EPhW_XXmnNs1I6dmS}/_ssgManifest.js +0 -0
  60. {skypilot_nightly-1.0.0.dev20250901.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/WHEEL +0 -0
  61. {skypilot_nightly-1.0.0.dev20250901.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/entry_points.txt +0 -0
  62. {skypilot_nightly-1.0.0.dev20250901.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/licenses/LICENSE +0 -0
  63. {skypilot_nightly-1.0.0.dev20250901.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/top_level.txt +0 -0
sky/server/requests/requests.py CHANGED
@@ -14,7 +14,7 @@ import threading
 import time
 import traceback
 from typing import (Any, AsyncContextManager, Callable, Dict, Generator, List,
-                    Optional, Tuple)
+                    NamedTuple, Optional, Tuple)
 
 import colorama
 import filelock
@@ -300,10 +300,11 @@ def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
     prevent killing the caller request.
     """
     request_ids = [
-        request_task.request_id for request_task in get_request_tasks(
+        request_task.request_id
+        for request_task in get_request_tasks(req_filter=RequestTaskFilter(
             cluster_names=[cluster_name],
             status=[RequestStatus.PENDING, RequestStatus.RUNNING],
-            exclude_request_names=[exclude_request_name])
+            exclude_request_names=[exclude_request_name]))
     ]
     kill_requests(request_ids)
 
@@ -323,11 +324,12 @@ def kill_requests(request_ids: Optional[List[str]] = None,
     """
     if request_ids is None:
         request_ids = [
-            request_task.request_id for request_task in get_request_tasks(
+            request_task.request_id
+            for request_task in get_request_tasks(req_filter=RequestTaskFilter(
                 user_id=user_id,
                 status=[RequestStatus.RUNNING, RequestStatus.PENDING],
                 # Avoid cancelling the cancel request itself.
-                exclude_request_names=['sky.api_cancel'])
+                exclude_request_names=['sky.api_cancel']))
         ]
     cancelled_request_ids = []
     for request_id in request_ids:
@@ -548,6 +550,40 @@ async def get_request_async(request_id: str) -> Optional[Request]:
     return await _get_request_no_lock_async(request_id)
 
 
+class StatusWithMsg(NamedTuple):
+    status: RequestStatus
+    status_msg: Optional[str] = None
+
+
+@init_db_async
+@metrics_lib.time_me_async
+async def get_request_status_async(
+    request_id: str,
+    include_msg: bool = False,
+) -> Optional[StatusWithMsg]:
+    """Get the status of a request.
+
+    Args:
+        request_id: The ID of the request.
+        include_msg: Whether to include the status message.
+
+    Returns:
+        The status of the request. If the request is not found, returns
+        None.
+    """
+    assert _DB is not None
+    columns = 'status'
+    if include_msg:
+        columns += ', status_msg'
+    sql = f'SELECT {columns} FROM {REQUEST_TABLE} WHERE request_id LIKE ?'
+    async with _DB.execute_fetchall_async(sql, (request_id + '%',)) as rows:
+        if rows is None or len(rows) == 0:
+            return None
+        status = RequestStatus(rows[0][0])
+        status_msg = rows[0][1] if include_msg else None
+        return StatusWithMsg(status, status_msg)
+
+
 @init_db
 @metrics_lib.time_me
 def create_if_not_exists(request: Request) -> bool:
@@ -570,17 +606,9 @@ async def create_if_not_exists_async(request: Request) -> bool:
     return True
 
 
-@init_db
-@metrics_lib.time_me
-def get_request_tasks(
-    status: Optional[List[RequestStatus]] = None,
-    cluster_names: Optional[List[str]] = None,
-    user_id: Optional[str] = None,
-    exclude_request_names: Optional[List[str]] = None,
-    include_request_names: Optional[List[str]] = None,
-    finished_before: Optional[float] = None,
-) -> List[Request]:
-    """Get a list of requests that match the given filters.
+@dataclasses.dataclass
+class RequestTaskFilter:
+    """Filter for requests.
 
     Args:
         status: a list of statuses of the requests to filter on.
@@ -598,51 +626,87 @@ def get_request_tasks(
         ValueError: If both exclude_request_names and include_request_names are
             provided.
     """
-    if exclude_request_names is not None and include_request_names is not None:
-        raise ValueError(
-            'Only one of exclude_request_names or include_request_names can be '
-            'provided, not both.')
-
-    filters = []
-    filter_params: List[Any] = []
-    if status is not None:
-        status_list_str = ','.join(repr(status.value) for status in status)
-        filters.append(f'status IN ({status_list_str})')
-    if exclude_request_names is not None:
-        exclude_request_names_str = ','.join(
-            repr(name) for name in exclude_request_names)
-        filters.append(f'name NOT IN ({exclude_request_names_str})')
-    if cluster_names is not None:
-        cluster_names_str = ','.join(repr(name) for name in cluster_names)
-        filters.append(f'{COL_CLUSTER_NAME} IN ({cluster_names_str})')
-    if user_id is not None:
-        filters.append(f'{COL_USER_ID} = ?')
-        filter_params.append(user_id)
-    if include_request_names is not None:
-        request_names_str = ','.join(
-            repr(name) for name in include_request_names)
-        filters.append(f'name IN ({request_names_str})')
-    if finished_before is not None:
-        filters.append('finished_at < ?')
-        filter_params.append(finished_before)
-    assert _DB is not None
-    with _DB.conn:
-        cursor = _DB.conn.cursor()
+    status: Optional[List[RequestStatus]] = None
+    cluster_names: Optional[List[str]] = None
+    user_id: Optional[str] = None
+    exclude_request_names: Optional[List[str]] = None
+    include_request_names: Optional[List[str]] = None
+    finished_before: Optional[float] = None
+
+    def __post_init__(self):
+        if (self.exclude_request_names is not None and
+                self.include_request_names is not None):
+            raise ValueError(
+                'Only one of exclude_request_names or include_request_names '
+                'can be provided, not both.')
+
+    def build_query(self) -> Tuple[str, List[Any]]:
+        """Build the SQL query and filter parameters.
+
+        Returns:
+            A tuple of (SQL, SQL parameters).
+        """
+        filters = []
+        filter_params: List[Any] = []
+        if self.status is not None:
+            status_list_str = ','.join(
+                repr(status.value) for status in self.status)
+            filters.append(f'status IN ({status_list_str})')
+        if self.exclude_request_names is not None:
+            exclude_request_names_str = ','.join(
+                repr(name) for name in self.exclude_request_names)
+            filters.append(f'name NOT IN ({exclude_request_names_str})')
+        if self.cluster_names is not None:
+            cluster_names_str = ','.join(
+                repr(name) for name in self.cluster_names)
+            filters.append(f'{COL_CLUSTER_NAME} IN ({cluster_names_str})')
+        if self.user_id is not None:
+            filters.append(f'{COL_USER_ID} = ?')
+            filter_params.append(self.user_id)
+        if self.include_request_names is not None:
+            request_names_str = ','.join(
+                repr(name) for name in self.include_request_names)
+            filters.append(f'name IN ({request_names_str})')
+        if self.finished_before is not None:
+            filters.append('finished_at < ?')
+            filter_params.append(self.finished_before)
         filter_str = ' AND '.join(filters)
         if filter_str:
             filter_str = f' WHERE {filter_str}'
         columns_str = ', '.join(REQUEST_COLUMNS)
-        cursor.execute(
-            f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str} '
-            'ORDER BY created_at DESC', filter_params)
+        return (f'SELECT {columns_str} FROM {REQUEST_TABLE}{filter_str} '
+                'ORDER BY created_at DESC'), filter_params
+
+
+@init_db
+@metrics_lib.time_me
+def get_request_tasks(req_filter: RequestTaskFilter) -> List[Request]:
+    """Get a list of requests that match the given filters.
+
+    Args:
+        req_filter: the filter to apply to the requests. Refer to
+            RequestTaskFilter for the details.
+    """
+    assert _DB is not None
+    with _DB.conn:
+        cursor = _DB.conn.cursor()
+        cursor.execute(*req_filter.build_query())
         rows = cursor.fetchall()
         if rows is None:
             return []
-    requests = []
-    for row in rows:
-        request = Request.from_row(row)
-        requests.append(request)
-    return requests
+    return [Request.from_row(row) for row in rows]
+
+
+@init_db_async
+@metrics_lib.time_me_async
+async def get_request_tasks_async(
+        req_filter: RequestTaskFilter) -> List[Request]:
+    """Async version of get_request_tasks."""
+    assert _DB is not None
+    async with _DB.execute_fetchall_async(*req_filter.build_query()) as rows:
+        if not rows:
+            return []
+        return [Request.from_row(row) for row in rows]
 
 
 @init_db_async
@@ -739,8 +803,10 @@ def clean_finished_requests_with_retention(retention_seconds: int):
         retention_seconds: Requests older than this many seconds will be
             deleted.
     """
-    reqs = get_request_tasks(status=RequestStatus.finished_status(),
-                             finished_before=time.time() - retention_seconds)
+    reqs = get_request_tasks(
+        req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
                                     finished_before=time.time() -
                                     retention_seconds))
 
     subprocess_utils.run_in_parallel(
         func=lambda req: req.log_path.unlink(missing_ok=True),
@@ -767,7 +833,7 @@ async def requests_gc_daemon():
         try:
             # Negative value disables the requests GC
             if retention_seconds >= 0:
-                clean_finished_requests_with_retention(retention_seconds)
+                await clean_finished_requests_with_retention(retention_seconds)
         except asyncio.CancelledError:
             logger.info('Requests GC daemon cancelled')
             break
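
Taken together, the requests.py changes move filtering from loose keyword arguments into a single RequestTaskFilter dataclass: __post_init__ rejects mutually exclusive options, build_query() centralizes the SQL construction, and the sync and async fetch paths share it. A minimal sketch of the new call pattern (field names are taken from the diff above; the import path is assumed from the file list):

    from sky.server.requests import requests as requests_lib

    # __post_init__ raises ValueError if both exclude_request_names and
    # include_request_names are set.
    flt = requests_lib.RequestTaskFilter(
        status=[requests_lib.RequestStatus.PENDING,
                requests_lib.RequestStatus.RUNNING],
        cluster_names=['my-cluster'],  # hypothetical cluster name
        exclude_request_names=['sky.api_cancel'])

    sql, params = flt.build_query()  # parameterized SELECT ... ORDER BY created_at DESC
    tasks = requests_lib.get_request_tasks(req_filter=flt)
    # or, from async code: await requests_lib.get_request_tasks_async(req_filter=flt)
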
sky/server/server.py CHANGED
@@ -24,6 +24,7 @@ import aiofiles
 import anyio
 import fastapi
 from fastapi.middleware import cors
+from sqlalchemy import pool
 import starlette.middleware.base
 import uvloop
 
@@ -1327,10 +1328,12 @@ async def provision_logs(cluster_body: payloads.ClusterNameBody,
                          tail: int = 0) -> fastapi.responses.StreamingResponse:
     """Streams the provision.log for the latest launch request of a cluster."""
     # Prefer clusters table first, then cluster_history as fallback.
-    log_path_str = global_user_state.get_cluster_provision_log_path(
+    log_path_str = await context_utils.to_thread(
+        global_user_state.get_cluster_provision_log_path,
         cluster_body.cluster_name)
     if not log_path_str:
-        log_path_str = global_user_state.get_cluster_history_provision_log_path(
+        log_path_str = await context_utils.to_thread(
+            global_user_state.get_cluster_history_provision_log_path,
             cluster_body.cluster_name)
     if not log_path_str:
         raise fastapi.HTTPException(
@@ -1429,27 +1432,29 @@ async def local_down(request: fastapi.Request) -> None:
 async def api_get(request_id: str) -> payloads.RequestPayload:
     """Gets a request with a given request ID prefix."""
     while True:
-        request_task = await requests_lib.get_request_async(request_id)
-        if request_task is None:
+        req_status = await requests_lib.get_request_status_async(request_id)
+        if req_status is None:
             print(f'No task with request ID {request_id}', flush=True)
             raise fastapi.HTTPException(
                 status_code=404, detail=f'Request {request_id!r} not found')
-        if request_task.status > requests_lib.RequestStatus.RUNNING:
-            if request_task.should_retry:
-                raise fastapi.HTTPException(
-                    status_code=503,
-                    detail=f'Request {request_id!r} should be retried')
-            request_error = request_task.get_error()
-            if request_error is not None:
-                raise fastapi.HTTPException(
-                    status_code=500, detail=request_task.encode().model_dump())
-            return request_task.encode()
-        elif (request_task.status == requests_lib.RequestStatus.RUNNING and
-              daemons.is_daemon_request_id(request_id)):
-            return request_task.encode()
+        if (req_status.status == requests_lib.RequestStatus.RUNNING and
+                daemons.is_daemon_request_id(request_id)):
+            # Daemon requests run forever, break without waiting for complete.
+            break
+        if req_status.status > requests_lib.RequestStatus.RUNNING:
+            break
         # yield control to allow other coroutines to run, sleep shortly
         # to avoid storming the DB and CPU in the meantime
         await asyncio.sleep(0.1)
+    request_task = await requests_lib.get_request_async(request_id)
+    if request_task.should_retry:
+        raise fastapi.HTTPException(
+            status_code=503, detail=f'Request {request_id!r} should be retried')
+    request_error = request_task.get_error()
+    if request_error is not None:
+        raise fastapi.HTTPException(status_code=500,
+                                    detail=request_task.encode().model_dump())
+    return request_task.encode()
 
 
 @app.get('/api/stream')
@@ -1606,10 +1611,9 @@ async def api_status(
             requests_lib.RequestStatus.PENDING,
             requests_lib.RequestStatus.RUNNING,
         ]
-        return [
-            request_task.readable_encode()
-            for request_task in requests_lib.get_request_tasks(status=statuses)
-        ]
+        request_tasks = await requests_lib.get_request_tasks_async(
+            req_filter=requests_lib.RequestTaskFilter(status=statuses))
+        return [r.readable_encode() for r in request_tasks]
     else:
         encoded_request_tasks = []
         for request_id in request_ids:
@@ -1808,17 +1812,20 @@ async def gpu_metrics() -> fastapi.Response:
 # === Internal APIs ===
 @app.get('/api/completion/cluster_name')
 async def complete_cluster_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_cluster_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_cluster_names_start_with, incomplete)
 
 
 @app.get('/api/completion/storage_name')
 async def complete_storage_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_storage_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_storage_names_start_with, incomplete)
 
 
 @app.get('/api/completion/volume_name')
 async def complete_volume_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_volume_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_volume_names_start_with, incomplete)
 
 
 @app.get('/api/completion/api_request')
@@ -1902,7 +1909,7 @@ if __name__ == '__main__':
     skyuvicorn.add_timestamp_prefix_for_server_logs()
 
     # Initialize global user state db
-    global_user_state.initialize_and_get_db()
+    global_user_state.initialize_and_get_db(pool.QueuePool)
     # Initialize request db
     requests_lib.reset_db_and_logs()
     # Restore the server user hash
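
A recurring edit in server.py: synchronous global_user_state lookups are wrapped in context_utils.to_thread so the blocking DB work runs on a worker thread instead of stalling the event loop. The stdlib asyncio.to_thread has the same call shape; a minimal sketch of the pattern (the lookup and handler below are hypothetical stand-ins, not SkyPilot code):

    import asyncio

    def blocking_lookup(incomplete: str) -> list:
        # Stand-in for a synchronous DB call such as
        # global_user_state.get_cluster_names_start_with.
        return [n for n in ('alpha', 'beta') if n.startswith(incomplete)]

    async def complete_name(incomplete: str) -> list:
        # Offload the blocking call; the loop keeps serving other requests
        # while the worker thread hits the database.
        return await asyncio.to_thread(blocking_lookup, incomplete)
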
sky/server/stream_utils.py CHANGED
@@ -75,8 +75,10 @@ async def log_streamer(request_id: Optional[str],
     last_waiting_msg = ''
     waiting_msg = (f'Waiting for {request_task.name!r} request to be '
                    f'scheduled: {request_id}')
-    while request_task.status < requests_lib.RequestStatus.RUNNING:
-        if request_task.status_msg is not None:
+    req_status = request_task.status
+    req_msg = request_task.status_msg
+    while req_status < requests_lib.RequestStatus.RUNNING:
+        if req_msg is not None:
             waiting_msg = request_task.status_msg
         if show_request_waiting_spinner:
             yield status_msg.update(f'[dim]{waiting_msg}[/dim]')
@@ -91,7 +93,10 @@ async def log_streamer(request_id: Optional[str],
         # polling the DB, which can be a bottleneck for high-concurrency
         # requests.
         await asyncio.sleep(0.1)
-        request_task = await requests_lib.get_request_async(request_id)
+        status_with_msg = await requests_lib.get_request_status_async(
+            request_id, include_msg=True)
+        req_status = status_with_msg.status
+        req_msg = status_with_msg.status_msg
         if not follow:
             break
     if show_request_waiting_spinner:
@@ -153,10 +158,13 @@ async def _tail_log_file(f: aiofiles.threadpool.binary.AsyncBufferedReader,
         line: Optional[bytes] = await f.readline()
         if not line:
             if request_id is not None:
-                request_task = await requests_lib.get_request_async(request_id)
-                if request_task.status > requests_lib.RequestStatus.RUNNING:
-                    if (request_task.status ==
+                req_status = await requests_lib.get_request_status_async(
+                    request_id)
+                if req_status.status > requests_lib.RequestStatus.RUNNING:
+                    if (req_status.status ==
                             requests_lib.RequestStatus.CANCELLED):
+                        request_task = await requests_lib.get_request_async(
+                            request_id)
                         if request_task.should_retry:
                             buffer.append(
                                 message_utils.encode_payload(
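
The streamer changes follow from get_request_status_async above: each 100 ms poll now reads only the status column (plus status_msg when requested) instead of deserializing the whole request row, and the full row is fetched just once, on the cancellation path. A simplified sketch of the resulting loop shape (assuming the requests_lib alias this file already uses; not the exact code):

    import asyncio

    from sky.server.requests import requests as requests_lib

    async def wait_until_scheduled(request_id: str) -> None:
        while True:
            s = await requests_lib.get_request_status_async(request_id,
                                                            include_msg=True)
            # s is a StatusWithMsg named tuple: (status, status_msg).
            if s is None or not s.status < requests_lib.RequestStatus.RUNNING:
                return
            await asyncio.sleep(0.1)  # same cadence the streamer uses
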
sky/server/uvicorn.py CHANGED
@@ -146,7 +146,8 @@ class Server(uvicorn.Server):
             requests_lib.RequestStatus.PENDING,
             requests_lib.RequestStatus.RUNNING,
         ]
-        reqs = requests_lib.get_request_tasks(status=statuses)
+        reqs = requests_lib.get_request_tasks(
+            req_filter=requests_lib.RequestTaskFilter(status=statuses))
         if not reqs:
             break
         logger.info(f'{len(reqs)} on-going requests '
sky/templates/kubernetes-ray.yml.j2 CHANGED
@@ -302,7 +302,7 @@ available_node_types:
         provreq.kueue.x-k8s.io/maxRunDurationSeconds: "{{k8s_max_run_duration_seconds|string}}"
       {% endif %}
     {% endif %}
-    # https://cloud.google.com/kubernetes-engine/docs/how-to/gpu-bandwidth-gpudirect-tcpx
+    # https://cloud.google.com/kubernetes-engine/docs/how-to/gpu-bandwidth-gpudirect-tcpx
     # Values from google cloud guide
     {% if k8s_enable_gpudirect_tcpx %}
     devices.gke.io/container.tcpx-daemon: |+
@@ -784,8 +784,8 @@ available_node_types:
           echo "Waiting for patch package to be installed..."
         done
         # Apply Ray patches for progress bar fix
-        ~/.local/bin/uv pip list | grep "ray " | grep 2.9.3 2>&1 > /dev/null && {
-          VIRTUAL_ENV=~/skypilot-runtime python -c "from sky.skylet.ray_patches import patch; patch()" || exit 1;
+        ~/.local/bin/uv pip list | grep "ray " | grep 2.9.3 2>&1 > /dev/null && {
+          VIRTUAL_ENV=~/skypilot-runtime python -c "from sky.skylet.ray_patches import patch; patch()" || exit 1;
         }
         touch /tmp/ray_skypilot_installation_complete
         echo "=== Ray and skypilot installation completed ==="
@@ -1202,7 +1202,7 @@ setup_commands:
   {%- endfor %}
   STEPS=("apt-ssh-setup" "runtime-setup" "env-setup")
   start_epoch=$(date +%s);
-
+
   # Wait for SSH setup to complete before proceeding
   if [ -f /tmp/apt_ssh_setup_started ]; then
     echo "=== Logs for asynchronous SSH setup ===";
@@ -1210,7 +1210,7 @@ setup_commands:
     { tail -f -n +1 /tmp/${STEPS[0]}.log & TAIL_PID=$!; echo "Tail PID: $TAIL_PID"; until [ -f /tmp/apt_ssh_setup_complete ]; do sleep 0.5; done; kill $TAIL_PID || true; };
     [ -f /tmp/${STEPS[0]}.failed ] && { echo "Error: ${STEPS[0]} failed. Exiting."; exit 1; } || true;
   fi
-
+
   echo "=== Logs for asynchronous ray and skypilot installation ===";
   if [ -f /tmp/skypilot_is_nimbus ]; then
     echo "=== Logs for asynchronous ray and skypilot installation ===";
sky/templates/runpod-ray.yml.j2 CHANGED
@@ -40,6 +40,14 @@ available_node_types:
         skypilot:ssh_public_key_content
       Preemptible: {{use_spot}}
       BidPerGPU: {{bid_per_gpu}}
+      {%- if volume_mounts and volume_mounts|length > 0 %}
+      VolumeMounts:
+      {%- for vm in volume_mounts %}
+        - VolumeNameOnCloud: {{ vm.volume_name_on_cloud }}
+          VolumeIdOnCloud: {{ vm.volume_id_on_cloud }}
+          MountPath: {{ vm.path }}
+      {%- endfor %}
+      {%- endif %}
 
 head_node_type: ray_head_default
 
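The new block threads volume mounts through to the RunPod node config; when volume_mounts is empty or absent it renders to nothing, so existing configs are unchanged. The snippet can be exercised standalone with Jinja2 to see what it emits (the input values below are hypothetical):

    import jinja2

    snippet = '''
    {%- if volume_mounts and volume_mounts|length > 0 %}
    VolumeMounts:
    {%- for vm in volume_mounts %}
      - VolumeNameOnCloud: {{ vm.volume_name_on_cloud }}
        VolumeIdOnCloud: {{ vm.volume_id_on_cloud }}
        MountPath: {{ vm.path }}
    {%- endfor %}
    {%- endif %}
    '''
    print(jinja2.Template(snippet).render(volume_mounts=[{
        'volume_name_on_cloud': 'my-vol',  # hypothetical values
        'volume_id_on_cloud': 'vol-123',
        'path': '/workspace',
    }]))
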
sky/utils/benchmark_utils.py ADDED
@@ -0,0 +1,60 @@
+"""Utility functions for benchmarking."""
+
+import functools
+import logging
+import time
+from typing import Callable, Optional
+
+from sky import sky_logging
+
+logger = sky_logging.init_logger(__name__)
+
+
+def log_execution_time(func: Optional[Callable] = None,
+                       *,
+                       name: Optional[str] = None,
+                       level: int = logging.DEBUG,
+                       precision: int = 4) -> Callable:
+    """Mark a function and log its execution time.
+
+    Args:
+        func: Function to decorate.
+        name: Name of the function.
+        level: Logging level.
+        precision: Number of decimal places (default: 4).
+
+    Usage:
+        from sky.utils import benchmark_utils
+
+        @benchmark_utils.log_execution_time
+        def my_function():
+            pass
+
+        @benchmark_utils.log_execution_time(name='my_module.my_function2')
+        def my_function2():
+            pass
+    """
+
+    def decorator(f: Callable) -> Callable:
+
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            nonlocal name
+            name = name or f.__name__
+            start_time = time.perf_counter()
+            try:
+                result = f(*args, **kwargs)
+                return result
+            finally:
+                end_time = time.perf_counter()
+                execution_time = end_time - start_time
+                log = (f'Method {name} executed in '
+                       f'{execution_time:.{precision}f}')
+                logger.log(level, log)
+
+        return wrapper
+
+    if func is None:
+        return decorator
+    else:
+        return decorator(func)
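
Per its docstring, the decorator supports both the bare form and the parameterized form (when called with only keyword arguments it returns the decorator itself). A short demo; the duration comes from time.perf_counter, so the logged figure is in seconds:

    import logging
    import time

    from sky.utils import benchmark_utils

    @benchmark_utils.log_execution_time  # bare form, logs at DEBUG
    def load_config():
        time.sleep(0.05)

    @benchmark_utils.log_execution_time(name='demo.train_step',
                                        level=logging.INFO,
                                        precision=2)
    def train_step():
        time.sleep(0.1)

    train_step()  # logs: Method demo.train_step executed in 0.10
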
sky/utils/command_runner.py CHANGED
@@ -41,6 +41,8 @@ RSYNC_FILTER_GITIGNORE = f'--filter=\'dir-merge,- {constants.GIT_IGNORE_FILE}\''
 # The git exclude file to support.
 GIT_EXCLUDE = '.git/info/exclude'
 RSYNC_EXCLUDE_OPTION = '--exclude-from={}'
+# Owner and group metadata is not needed for downloads.
+RSYNC_NO_OWNER_NO_GROUP_OPTION = '--no-owner --no-group'
 
 _HASH_MAX_LENGTH = 10
 _DEFAULT_CONNECT_TIMEOUT = 30
@@ -286,6 +288,8 @@ class CommandRunner:
         if prefix_command is not None:
             rsync_command.append(prefix_command)
         rsync_command += ['rsync', RSYNC_DISPLAY_OPTION]
+        if not up:
+            rsync_command.append(RSYNC_NO_OWNER_NO_GROUP_OPTION)
 
         # --filter
         # The source is a local path, so we need to resolve it.
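
The new option is appended only when up is False, i.e. for downloads from the remote machine: preserving the remote host's owner and group locally is pointless, and chown attempts can fail without privileges. A simplified sketch of the assembly logic (stand-in values; not the full method):

    from typing import List

    def build_rsync_command(up: bool) -> List[str]:
        rsync_command = ['rsync', '-Pavz']  # display-option stand-in
        if not up:
            # Download direction: drop remote owner/group metadata.
            rsync_command.append('--no-owner --no-group')
        return rsync_command
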
sky/utils/db/migration_utils.py CHANGED
@@ -4,6 +4,8 @@ import contextlib
 import logging
 import os
 import pathlib
+import threading
+from typing import Dict, Optional
 
 from alembic import command as alembic_command
 from alembic.config import Config
@@ -30,18 +32,32 @@ SERVE_DB_NAME = 'serve_db'
 SERVE_VERSION = '001'
 SERVE_LOCK_PATH = '~/.sky/locks/.serve_db.lock'
 
+_postgres_engine_cache: Dict[str, sqlalchemy.engine.Engine] = {}
+_sqlite_engine_cache: Dict[str, sqlalchemy.engine.Engine] = {}
 
-def get_engine(db_name: str):
+_db_creation_lock = threading.Lock()
+
+
+def get_engine(db_name: str,
+               pg_pool_class: Optional[sqlalchemy.pool.Pool] = None):
     conn_string = None
     if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
         conn_string = os.environ.get(constants.ENV_VAR_DB_CONNECTION_URI)
     if conn_string:
-        engine = sqlalchemy.create_engine(conn_string,
-                                          poolclass=sqlalchemy.NullPool)
+        if pg_pool_class is None:
+            pg_pool_class = sqlalchemy.NullPool
+        with _db_creation_lock:
+            if conn_string not in _postgres_engine_cache:
+                _postgres_engine_cache[conn_string] = sqlalchemy.create_engine(
+                    conn_string, poolclass=pg_pool_class)
+            engine = _postgres_engine_cache[conn_string]
     else:
         db_path = os.path.expanduser(f'~/.sky/{db_name}.db')
         pathlib.Path(db_path).parents[0].mkdir(parents=True, exist_ok=True)
-        engine = sqlalchemy.create_engine('sqlite:///' + db_path)
+        if db_path not in _sqlite_engine_cache:
+            _sqlite_engine_cache[db_path] = sqlalchemy.create_engine(
+                'sqlite:///' + db_path)
+        engine = _sqlite_engine_cache[db_path]
     return engine
sky/utils/resource_checker.py CHANGED
@@ -269,16 +269,17 @@ def _get_active_resources(
         all_managed_jobs: List[Dict[str, Any]]
     """
 
-    def get_all_clusters():
+    def get_all_clusters() -> List[Dict[str, Any]]:
         return global_user_state.get_clusters()
 
-    def get_all_managed_jobs():
+    def get_all_managed_jobs() -> List[Dict[str, Any]]:
         # pylint: disable=import-outside-toplevel
         from sky.jobs.server import core as managed_jobs_core
         try:
-            return managed_jobs_core.queue(refresh=False,
-                                           skip_finished=True,
-                                           all_users=True)
+            filtered_jobs, _, _, _ = managed_jobs_core.queue(refresh=False,
                                                             skip_finished=True,
                                                             all_users=True)
+            return filtered_jobs
         except exceptions.ClusterNotUpError:
             logger.warning('All jobs should be finished.')
             return []
sky/utils/schemas.py CHANGED
@@ -432,7 +432,7 @@ def get_volume_schema():
     return {
         '$schema': 'https://json-schema.org/draft/2020-12/schema',
         'type': 'object',
-        'required': ['name', 'type', 'infra'],
+        'required': ['name', 'type'],
         'additionalProperties': False,
         'properties': {
             'name': {
sky/utils/volume.py CHANGED
@@ -10,6 +10,8 @@ from sky.utils import common_utils
 from sky.utils import schemas
 from sky.utils import status_lib
 
+MIN_RUNPOD_NETWORK_VOLUME_SIZE_GB = 10
+
 
 class VolumeAccessMode(enum.Enum):
     """Volume access mode."""
@@ -22,6 +24,7 @@ class VolumeAccessMode(enum.Enum):
 class VolumeType(enum.Enum):
     """Volume type."""
     PVC = 'k8s-pvc'
+    RUNPOD_NETWORK_VOLUME = 'runpod-network-volume'
 
 
 class VolumeMount:
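
volume.py gains a RunPod size floor and a VolumeType member whose value is the 'runpod-network-volume' string used in volume YAML specs. A hypothetical validation helper built on the new constants (illustrative; not part of the diff):

    from sky.utils import volume as volume_lib

    def check_runpod_volume(size_gb: int, type_str: str) -> None:
        # Round-trip the YAML string through the enum; unknown types raise
        # ValueError.
        vtype = volume_lib.VolumeType(type_str)
        if (vtype is volume_lib.VolumeType.RUNPOD_NETWORK_VOLUME and
                size_gb < volume_lib.MIN_RUNPOD_NETWORK_VOLUME_SIZE_GB):
            raise ValueError(
                'RunPod network volumes must be at least '
                f'{volume_lib.MIN_RUNPOD_NETWORK_VOLUME_SIZE_GB} GB.')

    check_runpod_volume(10, 'runpod-network-volume')  # ok; 9 would raise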