skypilot-nightly 1.0.0.dev20250902__py3-none-any.whl → 1.0.0.dev20250903__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (59)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/runpod.py +68 -0
  3. sky/backends/backend_utils.py +5 -3
  4. sky/client/cli/command.py +20 -5
  5. sky/clouds/kubernetes.py +1 -1
  6. sky/clouds/runpod.py +17 -0
  7. sky/dashboard/out/404.html +1 -1
  8. sky/dashboard/out/_next/static/chunks/1121-ec35954c8cbea535.js +1 -0
  9. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-b77360a343d48902.js +16 -0
  10. sky/dashboard/out/_next/static/chunks/{webpack-0eaa6f7e63f51311.js → webpack-60556df644cd5d71.js} +1 -1
  11. sky/dashboard/out/_next/static/{tio0QibqY2C0F2-rPy00p → yLz6EPhW_XXmnNs1I6dmS}/_buildManifest.js +1 -1
  12. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  13. sky/dashboard/out/clusters/[cluster].html +1 -1
  14. sky/dashboard/out/clusters.html +1 -1
  15. sky/dashboard/out/config.html +1 -1
  16. sky/dashboard/out/index.html +1 -1
  17. sky/dashboard/out/infra/[context].html +1 -1
  18. sky/dashboard/out/infra.html +1 -1
  19. sky/dashboard/out/jobs/[job].html +1 -1
  20. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  21. sky/dashboard/out/jobs.html +1 -1
  22. sky/dashboard/out/users.html +1 -1
  23. sky/dashboard/out/volumes.html +1 -1
  24. sky/dashboard/out/workspace/new.html +1 -1
  25. sky/dashboard/out/workspaces/[name].html +1 -1
  26. sky/dashboard/out/workspaces.html +1 -1
  27. sky/global_user_state.py +5 -2
  28. sky/models.py +1 -0
  29. sky/provision/runpod/__init__.py +3 -0
  30. sky/provision/runpod/instance.py +17 -0
  31. sky/provision/runpod/utils.py +23 -5
  32. sky/provision/runpod/volume.py +158 -0
  33. sky/server/requests/payloads.py +7 -1
  34. sky/server/requests/preconditions.py +8 -7
  35. sky/server/requests/requests.py +123 -57
  36. sky/server/server.py +32 -25
  37. sky/server/stream_utils.py +14 -6
  38. sky/server/uvicorn.py +2 -1
  39. sky/templates/kubernetes-ray.yml.j2 +5 -5
  40. sky/templates/runpod-ray.yml.j2 +8 -0
  41. sky/utils/benchmark_utils.py +60 -0
  42. sky/utils/command_runner.py +4 -0
  43. sky/utils/db/migration_utils.py +20 -4
  44. sky/utils/resource_checker.py +6 -5
  45. sky/utils/schemas.py +1 -1
  46. sky/utils/volume.py +3 -0
  47. sky/volumes/client/sdk.py +28 -0
  48. sky/volumes/server/server.py +11 -1
  49. sky/volumes/utils.py +117 -68
  50. sky/volumes/volume.py +98 -39
  51. {skypilot_nightly-1.0.0.dev20250902.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/METADATA +34 -34
  52. {skypilot_nightly-1.0.0.dev20250902.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/RECORD +57 -55
  53. sky/dashboard/out/_next/static/chunks/1121-8afcf719ea87debc.js +0 -1
  54. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-06afb50d25f7c61f.js +0 -16
  55. /sky/dashboard/out/_next/static/{tio0QibqY2C0F2-rPy00p → yLz6EPhW_XXmnNs1I6dmS}/_ssgManifest.js +0 -0
  56. {skypilot_nightly-1.0.0.dev20250902.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/WHEEL +0 -0
  57. {skypilot_nightly-1.0.0.dev20250902.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/entry_points.txt +0 -0
  58. {skypilot_nightly-1.0.0.dev20250902.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/licenses/LICENSE +0 -0
  59. {skypilot_nightly-1.0.0.dev20250902.dist-info → skypilot_nightly-1.0.0.dev20250903.dist-info}/top_level.txt +0 -0
sky/server/server.py CHANGED
@@ -24,6 +24,7 @@ import aiofiles
 import anyio
 import fastapi
 from fastapi.middleware import cors
+from sqlalchemy import pool
 import starlette.middleware.base
 import uvloop
 
@@ -1327,10 +1328,12 @@ async def provision_logs(cluster_body: payloads.ClusterNameBody,
                          tail: int = 0) -> fastapi.responses.StreamingResponse:
     """Streams the provision.log for the latest launch request of a cluster."""
     # Prefer clusters table first, then cluster_history as fallback.
-    log_path_str = global_user_state.get_cluster_provision_log_path(
+    log_path_str = await context_utils.to_thread(
+        global_user_state.get_cluster_provision_log_path,
         cluster_body.cluster_name)
     if not log_path_str:
-        log_path_str = global_user_state.get_cluster_history_provision_log_path(
+        log_path_str = await context_utils.to_thread(
+            global_user_state.get_cluster_history_provision_log_path,
             cluster_body.cluster_name)
     if not log_path_str:
         raise fastapi.HTTPException(
@@ -1429,27 +1432,29 @@ async def local_down(request: fastapi.Request) -> None:
 async def api_get(request_id: str) -> payloads.RequestPayload:
     """Gets a request with a given request ID prefix."""
     while True:
-        request_task = await requests_lib.get_request_async(request_id)
-        if request_task is None:
+        req_status = await requests_lib.get_request_status_async(request_id)
+        if req_status is None:
             print(f'No task with request ID {request_id}', flush=True)
             raise fastapi.HTTPException(
                 status_code=404, detail=f'Request {request_id!r} not found')
-        if request_task.status > requests_lib.RequestStatus.RUNNING:
-            if request_task.should_retry:
-                raise fastapi.HTTPException(
-                    status_code=503,
-                    detail=f'Request {request_id!r} should be retried')
-            request_error = request_task.get_error()
-            if request_error is not None:
-                raise fastapi.HTTPException(
-                    status_code=500, detail=request_task.encode().model_dump())
-            return request_task.encode()
-        elif (request_task.status == requests_lib.RequestStatus.RUNNING and
-              daemons.is_daemon_request_id(request_id)):
-            return request_task.encode()
+        if (req_status.status == requests_lib.RequestStatus.RUNNING and
+                daemons.is_daemon_request_id(request_id)):
+            # Daemon requests run forever, break without waiting for complete.
+            break
+        if req_status.status > requests_lib.RequestStatus.RUNNING:
+            break
         # yield control to allow other coroutines to run, sleep shortly
         # to avoid storming the DB and CPU in the meantime
         await asyncio.sleep(0.1)
+    request_task = await requests_lib.get_request_async(request_id)
+    if request_task.should_retry:
+        raise fastapi.HTTPException(
+            status_code=503, detail=f'Request {request_id!r} should be retried')
+    request_error = request_task.get_error()
+    if request_error is not None:
+        raise fastapi.HTTPException(status_code=500,
+                                    detail=request_task.encode().model_dump())
+    return request_task.encode()
 
 
 @app.get('/api/stream')
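The rewritten api_get above polls only a lightweight status record inside the loop and fetches the full request row a single time after the loop exits. A self-contained sketch of that pattern, with stand-in helpers and statuses rather than SkyPilot's requests API:

    import asyncio
    from typing import Dict

    # Hypothetical in-memory store standing in for the requests table.
    _STATUS: Dict[str, str] = {'req-1': 'RUNNING'}
    _RECORDS: Dict[str, dict] = {'req-1': {'id': 'req-1', 'error': None}}

    async def get_status(request_id: str) -> str:
        return _STATUS[request_id]      # cheap query: status column only

    async def get_record(request_id: str) -> dict:
        return _RECORDS[request_id]     # expensive query: the full row

    async def wait_for_terminal(request_id: str) -> dict:
        # Poll the lightweight status; fetch the full record exactly once.
        while await get_status(request_id) not in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
            await asyncio.sleep(0.1)
        return await get_record(request_id)

    async def main():
        loop = asyncio.get_running_loop()
        loop.call_later(0.3, _STATUS.update, {'req-1': 'SUCCEEDED'})
        print(await wait_for_terminal('req-1'))

    asyncio.run(main())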
@@ -1606,10 +1611,9 @@ async def api_status(
             requests_lib.RequestStatus.PENDING,
             requests_lib.RequestStatus.RUNNING,
         ]
-        return [
-            request_task.readable_encode()
-            for request_task in requests_lib.get_request_tasks(status=statuses)
-        ]
+        request_tasks = await requests_lib.get_request_tasks_async(
+            req_filter=requests_lib.RequestTaskFilter(status=statuses))
+        return [r.readable_encode() for r in request_tasks]
     else:
         encoded_request_tasks = []
         for request_id in request_ids:
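get_request_tasks now takes a req_filter object instead of a bare status keyword. A minimal sketch of the bundled-filter idea; RequestTaskFilter's real fields beyond status are not shown in this diff, so TaskFilter below is purely illustrative:

    import dataclasses
    from typing import List, Optional

    @dataclasses.dataclass
    class TaskFilter:
        """Illustrative filter object; only `status` is visible in the diff above."""
        status: Optional[List[str]] = None
        user_id: Optional[str] = None  # hypothetical extra field

    def get_tasks(rows: List[dict], req_filter: TaskFilter) -> List[dict]:
        out = rows
        if req_filter.status is not None:
            out = [r for r in out if r['status'] in req_filter.status]
        if req_filter.user_id is not None:
            out = [r for r in out if r['user_id'] == req_filter.user_id]
        return out

    rows = [{'status': 'PENDING', 'user_id': 'a'},
            {'status': 'SUCCEEDED', 'user_id': 'b'}]
    print(get_tasks(rows, TaskFilter(status=['PENDING', 'RUNNING'])))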
@@ -1808,17 +1812,20 @@ async def gpu_metrics() -> fastapi.Response:
 # === Internal APIs ===
 @app.get('/api/completion/cluster_name')
 async def complete_cluster_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_cluster_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_cluster_names_start_with, incomplete)
 
 
 @app.get('/api/completion/storage_name')
 async def complete_storage_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_storage_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_storage_names_start_with, incomplete)
 
 
 @app.get('/api/completion/volume_name')
 async def complete_volume_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_volume_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_volume_names_start_with, incomplete)
 
 
 @app.get('/api/completion/api_request')
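The completion endpoints used to run their blocking DB lookups directly on the event loop; they are now offloaded to a worker thread. A sketch of the same pattern with the standard library only (the sqlite table and path are made up; SkyPilot uses its own context_utils.to_thread wrapper and global_user_state store):

    import asyncio
    import sqlite3

    DB_PATH = 'state.db'  # placeholder path for this sketch

    def _query_names(prefix: str) -> list:
        # Blocking sqlite call; it runs on a worker thread so the event loop stays free.
        conn = sqlite3.connect(DB_PATH)
        try:
            conn.execute('CREATE TABLE IF NOT EXISTS clusters (name TEXT)')
            rows = conn.execute('SELECT name FROM clusters WHERE name LIKE ?',
                                (prefix + '%',))
            return [r[0] for r in rows]
        finally:
            conn.close()

    async def complete_cluster_name(incomplete: str) -> list:
        # asyncio.to_thread is the standard-library analogue of the
        # context_utils.to_thread calls added above.
        return await asyncio.to_thread(_query_names, incomplete)

    print(asyncio.run(complete_cluster_name('my-')))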
@@ -1902,7 +1909,7 @@ if __name__ == '__main__':
     skyuvicorn.add_timestamp_prefix_for_server_logs()
 
     # Initialize global user state db
-    global_user_state.initialize_and_get_db()
+    global_user_state.initialize_and_get_db(pool.QueuePool)
     # Initialize request db
     requests_lib.reset_db_and_logs()
     # Restore the server user hash
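The server entry point now asks for a QueuePool-backed engine rather than the NullPool default used elsewhere. A sketch of the difference, assuming a placeholder connection string, illustrative pool sizes, and an installed Postgres DBAPI such as psycopg2 so create_engine can resolve the dialect:

    import sqlalchemy
    from sqlalchemy import pool

    URL = 'postgresql://user:pass@db-host/sky'  # placeholder connection string

    # NullPool: every checkout opens a fresh connection and closes it on release.
    ephemeral = sqlalchemy.create_engine(URL, poolclass=pool.NullPool)

    # QueuePool: a bounded set of persistent connections is reused across
    # requests, which suits a long-lived API server process.
    pooled = sqlalchemy.create_engine(URL, poolclass=pool.QueuePool,
                                      pool_size=5, max_overflow=10)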
sky/server/stream_utils.py CHANGED
@@ -75,8 +75,10 @@ async def log_streamer(request_id: Optional[str],
     last_waiting_msg = ''
     waiting_msg = (f'Waiting for {request_task.name!r} request to be '
                    f'scheduled: {request_id}')
-    while request_task.status < requests_lib.RequestStatus.RUNNING:
-        if request_task.status_msg is not None:
+    req_status = request_task.status
+    req_msg = request_task.status_msg
+    while req_status < requests_lib.RequestStatus.RUNNING:
+        if req_msg is not None:
             waiting_msg = request_task.status_msg
         if show_request_waiting_spinner:
             yield status_msg.update(f'[dim]{waiting_msg}[/dim]')
@@ -91,7 +93,10 @@ async def log_streamer(request_id: Optional[str],
         # polling the DB, which can be a bottleneck for high-concurrency
         # requests.
         await asyncio.sleep(0.1)
-        request_task = await requests_lib.get_request_async(request_id)
+        status_with_msg = await requests_lib.get_request_status_async(
+            request_id, include_msg=True)
+        req_status = status_with_msg.status
+        req_msg = status_with_msg.status_msg
         if not follow:
             break
     if show_request_waiting_spinner:
@@ -153,10 +158,13 @@ async def _tail_log_file(f: aiofiles.threadpool.binary.AsyncBufferedReader,
         line: Optional[bytes] = await f.readline()
         if not line:
             if request_id is not None:
-                request_task = await requests_lib.get_request_async(request_id)
-                if request_task.status > requests_lib.RequestStatus.RUNNING:
-                    if (request_task.status ==
+                req_status = await requests_lib.get_request_status_async(
+                    request_id)
+                if req_status.status > requests_lib.RequestStatus.RUNNING:
+                    if (req_status.status ==
                             requests_lib.RequestStatus.CANCELLED):
+                        request_task = await requests_lib.get_request_async(
+                            request_id)
                         if request_task.should_retry:
                             buffer.append(
                                 message_utils.encode_payload(
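The streaming helpers above poll a request's status while tailing its log with aiofiles. A stripped-down sketch of the file-tailing part only, with a placeholder log path and none of the request-status checks:

    import asyncio
    import aiofiles  # third-party package already used by the streaming code above

    async def tail(path: str, poll_interval: float = 0.1):
        # Follow a growing log file, yielding complete lines as they appear.
        async with aiofiles.open(path, 'rb') as f:
            while True:
                line = await f.readline()
                if not line:
                    await asyncio.sleep(poll_interval)  # nothing new yet; poll again
                    continue
                yield line.decode('utf-8', errors='replace')

    async def main():
        # 'provision.log' is a placeholder path; follow mode loops until cancelled.
        async for line in tail('provision.log'):
            print(line, end='')

    # asyncio.run(main())  # runs until cancelled, so left commented out here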
sky/server/uvicorn.py CHANGED
@@ -146,7 +146,8 @@ class Server(uvicorn.Server):
                 requests_lib.RequestStatus.PENDING,
                 requests_lib.RequestStatus.RUNNING,
             ]
-            reqs = requests_lib.get_request_tasks(status=statuses)
+            reqs = requests_lib.get_request_tasks(
+                req_filter=requests_lib.RequestTaskFilter(status=statuses))
             if not reqs:
                 break
             logger.info(f'{len(reqs)} on-going requests '
sky/templates/kubernetes-ray.yml.j2 CHANGED
@@ -302,7 +302,7 @@ available_node_types:
         provreq.kueue.x-k8s.io/maxRunDurationSeconds: "{{k8s_max_run_duration_seconds|string}}"
         {% endif %}
         {% endif %}
-        # https://cloud.google.com/kubernetes-engine/docs/how-to/gpu-bandwidth-gpudirect-tcpx
+        # https://cloud.google.com/kubernetes-engine/docs/how-to/gpu-bandwidth-gpudirect-tcpx
         # Values from google cloud guide
         {% if k8s_enable_gpudirect_tcpx %}
         devices.gke.io/container.tcpx-daemon: |+
@@ -784,8 +784,8 @@ available_node_types:
             echo "Waiting for patch package to be installed..."
           done
           # Apply Ray patches for progress bar fix
-          ~/.local/bin/uv pip list | grep "ray " | grep 2.9.3 2>&1 > /dev/null && {
-            VIRTUAL_ENV=~/skypilot-runtime python -c "from sky.skylet.ray_patches import patch; patch()" || exit 1;
+          ~/.local/bin/uv pip list | grep "ray " | grep 2.9.3 2>&1 > /dev/null && {
+            VIRTUAL_ENV=~/skypilot-runtime python -c "from sky.skylet.ray_patches import patch; patch()" || exit 1;
           }
           touch /tmp/ray_skypilot_installation_complete
           echo "=== Ray and skypilot installation completed ==="
@@ -1202,7 +1202,7 @@ setup_commands:
   {%- endfor %}
   STEPS=("apt-ssh-setup" "runtime-setup" "env-setup")
   start_epoch=$(date +%s);
-
+
   # Wait for SSH setup to complete before proceeding
   if [ -f /tmp/apt_ssh_setup_started ]; then
     echo "=== Logs for asynchronous SSH setup ===";
@@ -1210,7 +1210,7 @@ setup_commands:
   { tail -f -n +1 /tmp/${STEPS[0]}.log & TAIL_PID=$!; echo "Tail PID: $TAIL_PID"; until [ -f /tmp/apt_ssh_setup_complete ]; do sleep 0.5; done; kill $TAIL_PID || true; };
   [ -f /tmp/${STEPS[0]}.failed ] && { echo "Error: ${STEPS[0]} failed. Exiting."; exit 1; } || true;
   fi
-
+
   echo "=== Logs for asynchronous ray and skypilot installation ===";
   if [ -f /tmp/skypilot_is_nimbus ]; then
     echo "=== Logs for asynchronous ray and skypilot installation ===";
sky/templates/runpod-ray.yml.j2 CHANGED
@@ -40,6 +40,14 @@ available_node_types:
       skypilot:ssh_public_key_content
       Preemptible: {{use_spot}}
       BidPerGPU: {{bid_per_gpu}}
+      {%- if volume_mounts and volume_mounts|length > 0 %}
+      VolumeMounts:
+      {%- for vm in volume_mounts %}
+        - VolumeNameOnCloud: {{ vm.volume_name_on_cloud }}
+          VolumeIdOnCloud: {{ vm.volume_id_on_cloud }}
+          MountPath: {{ vm.path }}
+      {%- endfor %}
+      {%- endif %}
 
 head_node_type: ray_head_default
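The new VolumeMounts block can be previewed by rendering it on its own. A sketch with made-up values, where jinja2 is the engine these .j2 templates target and the VM class is a stand-in for whatever object the provisioner actually passes:

    import jinja2

    # The VolumeMounts block added above, rendered standalone with made-up values.
    template = jinja2.Template(
        '{%- if volume_mounts and volume_mounts|length > 0 %}\n'
        'VolumeMounts:\n'
        '{%- for vm in volume_mounts %}\n'
        '  - VolumeNameOnCloud: {{ vm.volume_name_on_cloud }}\n'
        '    VolumeIdOnCloud: {{ vm.volume_id_on_cloud }}\n'
        '    MountPath: {{ vm.path }}\n'
        '{%- endfor %}\n'
        '{%- endif %}\n')

    class VM:  # minimal stand-in for the volume-mount object the provisioner passes
        def __init__(self, name, vol_id, path):
            self.volume_name_on_cloud = name
            self.volume_id_on_cloud = vol_id
            self.path = path

    print(template.render(volume_mounts=[VM('vol-1', 'abc123', '/data')]))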
sky/utils/benchmark_utils.py ADDED
@@ -0,0 +1,60 @@
+"""Utility functions for benchmarking."""
+
+import functools
+import logging
+import time
+from typing import Callable, Optional
+
+from sky import sky_logging
+
+logger = sky_logging.init_logger(__name__)
+
+
+def log_execution_time(func: Optional[Callable] = None,
+                       *,
+                       name: Optional[str] = None,
+                       level: int = logging.DEBUG,
+                       precision: int = 4) -> Callable:
+    """Mark a function and log its execution time.
+
+    Args:
+        func: Function to decorate.
+        name: Name of the function.
+        level: Logging level.
+        precision: Number of decimal places (default: 4).
+
+    Usage:
+        from sky.utils import benchmark_utils
+
+        @benchmark_utils.log_execution_time
+        def my_function():
+            pass
+
+        @benchmark_utils.log_execution_time(name='my_module.my_function2')
+        def my_function2():
+            pass
+    """
+
+    def decorator(f: Callable) -> Callable:
+
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            nonlocal name
+            name = name or f.__name__
+            start_time = time.perf_counter()
+            try:
+                result = f(*args, **kwargs)
+                return result
+            finally:
+                end_time = time.perf_counter()
+                execution_time = end_time - start_time
+                log = (f'Method {name} executed in '
+                       f'{execution_time:.{precision}f}')
+                logger.log(level, log)
+
+        return wrapper
+
+    if func is None:
+        return decorator
+    else:
+        return decorator(func)
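A short usage sketch of the new decorator; the function bodies and printed timings are illustrative, and DEBUG-level messages only appear if the sky logger is configured to show them:

    import logging
    import time

    from sky.utils import benchmark_utils

    @benchmark_utils.log_execution_time          # logs at DEBUG under the function's name
    def load_state():
        time.sleep(0.2)

    @benchmark_utils.log_execution_time(name='db.migrate', level=logging.INFO,
                                        precision=2)
    def migrate():
        time.sleep(0.1)

    load_state()   # logs roughly: "Method load_state executed in 0.2003"
    migrate()      # logs roughly: "Method db.migrate executed in 0.10"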
sky/utils/command_runner.py CHANGED
@@ -41,6 +41,8 @@ RSYNC_FILTER_GITIGNORE = f'--filter=\'dir-merge,- {constants.GIT_IGNORE_FILE}\''
 # The git exclude file to support.
 GIT_EXCLUDE = '.git/info/exclude'
 RSYNC_EXCLUDE_OPTION = '--exclude-from={}'
+# Owner and group metadata is not needed for downloads.
+RSYNC_NO_OWNER_NO_GROUP_OPTION = '--no-owner --no-group'
 
 _HASH_MAX_LENGTH = 10
 _DEFAULT_CONNECT_TIMEOUT = 30
@@ -286,6 +288,8 @@ class CommandRunner:
         if prefix_command is not None:
             rsync_command.append(prefix_command)
         rsync_command += ['rsync', RSYNC_DISPLAY_OPTION]
+        if not up:
+            rsync_command.append(RSYNC_NO_OWNER_NO_GROUP_OPTION)
 
         # --filter
         # The source is a local path, so we need to resolve it.
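The rsync change only applies on downloads (up=False). A standalone sketch of the flag handling, where the -Pavz base flags and the helper function are illustrative rather than CommandRunner's real command assembly:

    import shlex

    RSYNC_NO_OWNER_NO_GROUP_OPTION = '--no-owner --no-group'  # as defined above

    def build_rsync_cmd(src: str, dst: str, up: bool) -> list:
        # Generic flag set for the sketch; not CommandRunner's actual options.
        cmd = ['rsync', '-Pavz']
        if not up:
            # Downloads drop owner/group metadata, so files land owned by the
            # local user instead of the remote UID/GID.
            cmd += shlex.split(RSYNC_NO_OWNER_NO_GROUP_OPTION)
        return cmd + [src, dst]

    print(build_rsync_cmd('node:/tmp/logs/', './logs/', up=False))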
sky/utils/db/migration_utils.py CHANGED
@@ -4,6 +4,8 @@ import contextlib
 import logging
 import os
 import pathlib
+import threading
+from typing import Dict, Optional
 
 from alembic import command as alembic_command
 from alembic.config import Config
@@ -30,18 +32,32 @@ SERVE_DB_NAME = 'serve_db'
 SERVE_VERSION = '001'
 SERVE_LOCK_PATH = '~/.sky/locks/.serve_db.lock'
 
+_postgres_engine_cache: Dict[str, sqlalchemy.engine.Engine] = {}
+_sqlite_engine_cache: Dict[str, sqlalchemy.engine.Engine] = {}
 
-def get_engine(db_name: str):
+_db_creation_lock = threading.Lock()
+
+
+def get_engine(db_name: str,
+               pg_pool_class: Optional[sqlalchemy.pool.Pool] = None):
     conn_string = None
     if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
         conn_string = os.environ.get(constants.ENV_VAR_DB_CONNECTION_URI)
     if conn_string:
-        engine = sqlalchemy.create_engine(conn_string,
-                                          poolclass=sqlalchemy.NullPool)
+        if pg_pool_class is None:
+            pg_pool_class = sqlalchemy.NullPool
+        with _db_creation_lock:
+            if conn_string not in _postgres_engine_cache:
+                _postgres_engine_cache[conn_string] = sqlalchemy.create_engine(
+                    conn_string, poolclass=pg_pool_class)
+            engine = _postgres_engine_cache[conn_string]
     else:
         db_path = os.path.expanduser(f'~/.sky/{db_name}.db')
         pathlib.Path(db_path).parents[0].mkdir(parents=True, exist_ok=True)
-        engine = sqlalchemy.create_engine('sqlite:///' + db_path)
+        if db_path not in _sqlite_engine_cache:
+            _sqlite_engine_cache[db_path] = sqlalchemy.create_engine(
+                'sqlite:///' + db_path)
+        engine = _sqlite_engine_cache[db_path]
     return engine
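Engines are now cached per connection string (Postgres) or per database path (sqlite) behind a creation lock, so callers share one engine per process. A small usage sketch; the db name is illustrative and, with no Postgres URI configured, resolves to ~/.sky/<name>.db:

    from sky.utils.db import migration_utils

    # Without ENV_VAR_DB_CONNECTION_URI set, get_engine falls back to sqlite and
    # caches one engine per database path, so repeated calls return the same object.
    a = migration_utils.get_engine('state_db')   # 'state_db' is an illustrative name
    b = migration_utils.get_engine('state_db')
    assert a is b

    # On an API server configured with a Postgres connection URI, the pool class
    # can be overridden, e.g.:
    # engine = migration_utils.get_engine('state_db',
    #                                     pg_pool_class=sqlalchemy.pool.QueuePool)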
sky/utils/resource_checker.py CHANGED
@@ -269,16 +269,17 @@ def _get_active_resources(
         all_managed_jobs: List[Dict[str, Any]]
     """
 
-    def get_all_clusters():
+    def get_all_clusters() -> List[Dict[str, Any]]:
        return global_user_state.get_clusters()
 
-    def get_all_managed_jobs():
+    def get_all_managed_jobs() -> List[Dict[str, Any]]:
        # pylint: disable=import-outside-toplevel
        from sky.jobs.server import core as managed_jobs_core
        try:
-            return managed_jobs_core.queue(refresh=False,
-                                           skip_finished=True,
-                                           all_users=True)
+            filtered_jobs, _, _, _ = managed_jobs_core.queue(refresh=False,
+                                                             skip_finished=True,
+                                                             all_users=True)
+            return filtered_jobs
        except exceptions.ClusterNotUpError:
            logger.warning('All jobs should be finished.')
            return []
sky/utils/schemas.py CHANGED
@@ -432,7 +432,7 @@ def get_volume_schema():
     return {
         '$schema': 'https://json-schema.org/draft/2020-12/schema',
         'type': 'object',
-        'required': ['name', 'type', 'infra'],
+        'required': ['name', 'type'],
         'additionalProperties': False,
         'properties': {
             'name': {
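With 'infra' dropped from the required list, a volume config that names only the volume and its type now passes schema validation. A trimmed-down check; the property list below is a small subset of the real schema, and jsonschema stands in for SkyPilot's own validator:

    import jsonschema  # third-party validator, used here only for illustration

    # Trimmed-down copy of the volume schema above: 'infra' is no longer
    # required, so a config that omits it now validates.
    schema = {
        'type': 'object',
        'required': ['name', 'type'],
        'properties': {
            'name': {'type': 'string'},
            'type': {'type': 'string'},
            'infra': {'type': 'string'},
            'size': {'type': 'string'},
        },
    }

    jsonschema.validate({'name': 'vol', 'type': 'k8s-pvc', 'size': '100GB'}, schema)
    print('volume config without infra is accepted')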
sky/utils/volume.py CHANGED
@@ -10,6 +10,8 @@ from sky.utils import common_utils
 from sky.utils import schemas
 from sky.utils import status_lib
 
+MIN_RUNPOD_NETWORK_VOLUME_SIZE_GB = 10
+
 
 class VolumeAccessMode(enum.Enum):
     """Volume access mode."""
@@ -22,6 +24,7 @@ class VolumeAccessMode(enum.Enum):
 class VolumeType(enum.Enum):
     """Volume type."""
     PVC = 'k8s-pvc'
+    RUNPOD_NETWORK_VOLUME = 'runpod-network-volume'
 
 
 class VolumeMount:
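The new volume type and minimum-size constant are importable from sky.utils.volume. A small sketch; the size check itself is illustrative and not the project's actual validation path:

    from sky.utils import volume

    def check_runpod_volume_size(size_gb: int) -> None:
        # Illustrative check only; SkyPilot's own validation lives in the volume code.
        if size_gb < volume.MIN_RUNPOD_NETWORK_VOLUME_SIZE_GB:
            raise ValueError(
                f'RunPod network volumes must be at least '
                f'{volume.MIN_RUNPOD_NETWORK_VOLUME_SIZE_GB} GB, got {size_gb} GB')

    print(volume.VolumeType.RUNPOD_NETWORK_VOLUME.value)  # 'runpod-network-volume'
    check_runpod_volume_size(50)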
sky/volumes/client/sdk.py CHANGED
@@ -27,6 +27,34 @@ logger = sky_logging.init_logger(__name__)
 def apply(volume: volume_lib.Volume) -> server_common.RequestId[None]:
     """Creates or registers a volume.
 
+    Example:
+        .. code-block:: python
+
+            import sky.volumes
+            cfg = {
+                'name': 'pvc',
+                'type': 'k8s-pvc',
+                'size': '100GB',
+                'labels': {
+                    'key': 'value',
+                },
+            }
+            vol = sky.volumes.Volume.from_yaml_config(cfg)
+            request_id = sky.volumes.apply(vol)
+            sky.get(request_id)
+
+        or
+
+            import sky.volumes
+            vol = sky.volumes.Volume(
+                name='vol',
+                type='runpod-network-volume',
+                infra='runpod/ca/CA-MTL-1',
+                size='100GB',
+            )
+            request_id = sky.volumes.apply(vol)
+            sky.get(request_id)
+
     Args:
         volume: The volume to apply.
 
sky/volumes/server/server.py CHANGED
@@ -19,10 +19,15 @@ router = fastapi.APIRouter()
 @router.get('')
 async def volume_list(request: fastapi.Request) -> None:
     """Gets the volumes."""
+    auth_user = request.state.auth_user
+    auth_user_env_vars_kwargs = {
+        'env_vars': auth_user.to_env_vars()
+    } if auth_user else {}
+    volume_list_body = payloads.VolumeListBody(**auth_user_env_vars_kwargs)
     executor.schedule_request(
         request_id=request.state.request_id,
         request_name='volume_list',
-        request_body=payloads.RequestBody(),
+        request_body=volume_list_body,
         func=core.volume_list,
         schedule_type=requests_lib.ScheduleType.SHORT,
     )
@@ -76,6 +81,11 @@ async def volume_apply(request: fastapi.Request,
     elif access_mode not in supported_access_modes:
         raise fastapi.HTTPException(
             status_code=400, detail=f'Invalid access mode: {access_mode}')
+    elif volume_type == volume.VolumeType.RUNPOD_NETWORK_VOLUME.value:
+        if not cloud.is_same_cloud(clouds.RunPod()):
+            raise fastapi.HTTPException(
+                status_code=400,
+                detail='Runpod network volume is only supported on Runpod')
     executor.schedule_request(
         request_id=request.state.request_id,
         request_name='volume_apply',