skypilot-nightly 1.0.0.dev20250502__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +22 -6
- sky/adaptors/aws.py +81 -16
- sky/adaptors/common.py +25 -2
- sky/adaptors/coreweave.py +278 -0
- sky/adaptors/do.py +8 -2
- sky/adaptors/gcp.py +11 -0
- sky/adaptors/hyperbolic.py +8 -0
- sky/adaptors/ibm.py +5 -2
- sky/adaptors/kubernetes.py +149 -18
- sky/adaptors/nebius.py +173 -30
- sky/adaptors/primeintellect.py +1 -0
- sky/adaptors/runpod.py +68 -0
- sky/adaptors/seeweb.py +183 -0
- sky/adaptors/shadeform.py +89 -0
- sky/admin_policy.py +187 -4
- sky/authentication.py +179 -225
- sky/backends/__init__.py +4 -2
- sky/backends/backend.py +22 -9
- sky/backends/backend_utils.py +1323 -397
- sky/backends/cloud_vm_ray_backend.py +1749 -1029
- sky/backends/docker_utils.py +1 -1
- sky/backends/local_docker_backend.py +11 -6
- sky/backends/task_codegen.py +633 -0
- sky/backends/wheel_utils.py +55 -9
- sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
- sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
- sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
- sky/{clouds/service_catalog → catalog}/common.py +90 -49
- sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
- sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +116 -80
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +70 -16
- sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
- sky/catalog/data_fetchers/fetch_nebius.py +338 -0
- sky/catalog/data_fetchers/fetch_runpod.py +698 -0
- sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
- sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
- sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
- sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
- sky/catalog/hyperbolic_catalog.py +136 -0
- sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
- sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
- sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
- sky/catalog/primeintellect_catalog.py +95 -0
- sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
- sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
- sky/catalog/seeweb_catalog.py +184 -0
- sky/catalog/shadeform_catalog.py +165 -0
- sky/catalog/ssh_catalog.py +167 -0
- sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
- sky/check.py +533 -185
- sky/cli.py +5 -5975
- sky/client/{cli.py → cli/command.py} +2591 -1956
- sky/client/cli/deprecation_utils.py +99 -0
- sky/client/cli/flags.py +359 -0
- sky/client/cli/table_utils.py +322 -0
- sky/client/cli/utils.py +79 -0
- sky/client/common.py +78 -32
- sky/client/oauth.py +82 -0
- sky/client/sdk.py +1219 -319
- sky/client/sdk_async.py +827 -0
- sky/client/service_account_auth.py +47 -0
- sky/cloud_stores.py +82 -3
- sky/clouds/__init__.py +13 -0
- sky/clouds/aws.py +564 -164
- sky/clouds/azure.py +105 -83
- sky/clouds/cloud.py +140 -40
- sky/clouds/cudo.py +68 -50
- sky/clouds/do.py +66 -48
- sky/clouds/fluidstack.py +63 -44
- sky/clouds/gcp.py +339 -110
- sky/clouds/hyperbolic.py +293 -0
- sky/clouds/ibm.py +70 -49
- sky/clouds/kubernetes.py +570 -162
- sky/clouds/lambda_cloud.py +74 -54
- sky/clouds/nebius.py +210 -81
- sky/clouds/oci.py +88 -66
- sky/clouds/paperspace.py +61 -44
- sky/clouds/primeintellect.py +317 -0
- sky/clouds/runpod.py +164 -74
- sky/clouds/scp.py +89 -86
- sky/clouds/seeweb.py +477 -0
- sky/clouds/shadeform.py +400 -0
- sky/clouds/ssh.py +263 -0
- sky/clouds/utils/aws_utils.py +10 -4
- sky/clouds/utils/gcp_utils.py +87 -11
- sky/clouds/utils/oci_utils.py +38 -14
- sky/clouds/utils/scp_utils.py +231 -167
- sky/clouds/vast.py +99 -77
- sky/clouds/vsphere.py +51 -40
- sky/core.py +375 -173
- sky/dag.py +15 -0
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
- sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
- sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
- sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
- sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
- sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +1 -0
- sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
- sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
- sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
- sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
- sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
- sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
- sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
- sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
- sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
- sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
- sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
- sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
- sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
- sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
- sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-84a40f8c7c627fe4.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -0
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -0
- sky/dashboard/out/infra.html +1 -0
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -0
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -0
- sky/dashboard/out/volumes.html +1 -0
- sky/dashboard/out/workspace/new.html +1 -0
- sky/dashboard/out/workspaces/[name].html +1 -0
- sky/dashboard/out/workspaces.html +1 -0
- sky/data/data_utils.py +137 -1
- sky/data/mounting_utils.py +269 -84
- sky/data/storage.py +1460 -1807
- sky/data/storage_utils.py +43 -57
- sky/exceptions.py +126 -2
- sky/execution.py +216 -63
- sky/global_user_state.py +2390 -586
- sky/jobs/__init__.py +7 -0
- sky/jobs/client/sdk.py +300 -58
- sky/jobs/client/sdk_async.py +161 -0
- sky/jobs/constants.py +15 -8
- sky/jobs/controller.py +848 -275
- sky/jobs/file_content_utils.py +128 -0
- sky/jobs/log_gc.py +193 -0
- sky/jobs/recovery_strategy.py +402 -152
- sky/jobs/scheduler.py +314 -189
- sky/jobs/server/core.py +836 -255
- sky/jobs/server/server.py +156 -115
- sky/jobs/server/utils.py +136 -0
- sky/jobs/state.py +2109 -706
- sky/jobs/utils.py +1306 -215
- sky/logs/__init__.py +21 -0
- sky/logs/agent.py +108 -0
- sky/logs/aws.py +243 -0
- sky/logs/gcp.py +91 -0
- sky/metrics/__init__.py +0 -0
- sky/metrics/utils.py +453 -0
- sky/models.py +78 -1
- sky/optimizer.py +164 -70
- sky/provision/__init__.py +90 -4
- sky/provision/aws/config.py +147 -26
- sky/provision/aws/instance.py +136 -50
- sky/provision/azure/instance.py +11 -6
- sky/provision/common.py +13 -1
- sky/provision/cudo/cudo_machine_type.py +1 -1
- sky/provision/cudo/cudo_utils.py +14 -8
- sky/provision/cudo/cudo_wrapper.py +72 -71
- sky/provision/cudo/instance.py +10 -6
- sky/provision/do/instance.py +10 -6
- sky/provision/do/utils.py +4 -3
- sky/provision/docker_utils.py +140 -33
- sky/provision/fluidstack/instance.py +13 -8
- sky/provision/gcp/__init__.py +1 -0
- sky/provision/gcp/config.py +301 -19
- sky/provision/gcp/constants.py +218 -0
- sky/provision/gcp/instance.py +36 -8
- sky/provision/gcp/instance_utils.py +18 -4
- sky/provision/gcp/volume_utils.py +247 -0
- sky/provision/hyperbolic/__init__.py +12 -0
- sky/provision/hyperbolic/config.py +10 -0
- sky/provision/hyperbolic/instance.py +437 -0
- sky/provision/hyperbolic/utils.py +373 -0
- sky/provision/instance_setup.py +101 -20
- sky/provision/kubernetes/__init__.py +5 -0
- sky/provision/kubernetes/config.py +9 -52
- sky/provision/kubernetes/constants.py +17 -0
- sky/provision/kubernetes/instance.py +919 -280
- sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
- sky/provision/kubernetes/network.py +27 -17
- sky/provision/kubernetes/network_utils.py +44 -43
- sky/provision/kubernetes/utils.py +1221 -534
- sky/provision/kubernetes/volume.py +343 -0
- sky/provision/lambda_cloud/instance.py +22 -16
- sky/provision/nebius/constants.py +50 -0
- sky/provision/nebius/instance.py +19 -6
- sky/provision/nebius/utils.py +237 -137
- sky/provision/oci/instance.py +10 -5
- sky/provision/paperspace/instance.py +10 -7
- sky/provision/paperspace/utils.py +1 -1
- sky/provision/primeintellect/__init__.py +10 -0
- sky/provision/primeintellect/config.py +11 -0
- sky/provision/primeintellect/instance.py +454 -0
- sky/provision/primeintellect/utils.py +398 -0
- sky/provision/provisioner.py +117 -36
- sky/provision/runpod/__init__.py +5 -0
- sky/provision/runpod/instance.py +27 -6
- sky/provision/runpod/utils.py +51 -18
- sky/provision/runpod/volume.py +214 -0
- sky/provision/scp/__init__.py +15 -0
- sky/provision/scp/config.py +93 -0
- sky/provision/scp/instance.py +707 -0
- sky/provision/seeweb/__init__.py +11 -0
- sky/provision/seeweb/config.py +13 -0
- sky/provision/seeweb/instance.py +812 -0
- sky/provision/shadeform/__init__.py +11 -0
- sky/provision/shadeform/config.py +12 -0
- sky/provision/shadeform/instance.py +351 -0
- sky/provision/shadeform/shadeform_utils.py +83 -0
- sky/provision/ssh/__init__.py +18 -0
- sky/provision/vast/instance.py +13 -8
- sky/provision/vast/utils.py +10 -7
- sky/provision/volume.py +164 -0
- sky/provision/vsphere/common/ssl_helper.py +1 -1
- sky/provision/vsphere/common/vapiconnect.py +2 -1
- sky/provision/vsphere/common/vim_utils.py +4 -4
- sky/provision/vsphere/instance.py +15 -10
- sky/provision/vsphere/vsphere_utils.py +17 -20
- sky/py.typed +0 -0
- sky/resources.py +845 -119
- sky/schemas/__init__.py +0 -0
- sky/schemas/api/__init__.py +0 -0
- sky/schemas/api/responses.py +227 -0
- sky/schemas/db/README +4 -0
- sky/schemas/db/env.py +90 -0
- sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
- sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
- sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
- sky/schemas/db/global_user_state/004_is_managed.py +34 -0
- sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
- sky/schemas/db/global_user_state/006_provision_log.py +41 -0
- sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
- sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
- sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
- sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
- sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
- sky/schemas/db/script.py.mako +28 -0
- sky/schemas/db/serve_state/001_initial_schema.py +67 -0
- sky/schemas/db/serve_state/002_yaml_content.py +34 -0
- sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
- sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
- sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
- sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
- sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
- sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
- sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
- sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
- sky/schemas/generated/__init__.py +0 -0
- sky/schemas/generated/autostopv1_pb2.py +36 -0
- sky/schemas/generated/autostopv1_pb2.pyi +43 -0
- sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
- sky/schemas/generated/jobsv1_pb2.py +86 -0
- sky/schemas/generated/jobsv1_pb2.pyi +254 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
- sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
- sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
- sky/schemas/generated/servev1_pb2.py +58 -0
- sky/schemas/generated/servev1_pb2.pyi +115 -0
- sky/schemas/generated/servev1_pb2_grpc.py +322 -0
- sky/serve/autoscalers.py +357 -5
- sky/serve/client/impl.py +310 -0
- sky/serve/client/sdk.py +47 -139
- sky/serve/client/sdk_async.py +130 -0
- sky/serve/constants.py +12 -9
- sky/serve/controller.py +68 -17
- sky/serve/load_balancer.py +106 -60
- sky/serve/load_balancing_policies.py +116 -2
- sky/serve/replica_managers.py +434 -249
- sky/serve/serve_rpc_utils.py +179 -0
- sky/serve/serve_state.py +569 -257
- sky/serve/serve_utils.py +775 -265
- sky/serve/server/core.py +66 -711
- sky/serve/server/impl.py +1093 -0
- sky/serve/server/server.py +21 -18
- sky/serve/service.py +192 -89
- sky/serve/service_spec.py +144 -20
- sky/serve/spot_placer.py +3 -0
- sky/server/auth/__init__.py +0 -0
- sky/server/auth/authn.py +50 -0
- sky/server/auth/loopback.py +38 -0
- sky/server/auth/oauth2_proxy.py +202 -0
- sky/server/common.py +478 -182
- sky/server/config.py +85 -23
- sky/server/constants.py +44 -6
- sky/server/daemons.py +295 -0
- sky/server/html/token_page.html +185 -0
- sky/server/metrics.py +160 -0
- sky/server/middleware_utils.py +166 -0
- sky/server/requests/executor.py +558 -138
- sky/server/requests/payloads.py +364 -24
- sky/server/requests/preconditions.py +21 -17
- sky/server/requests/process.py +112 -29
- sky/server/requests/request_names.py +121 -0
- sky/server/requests/requests.py +822 -226
- sky/server/requests/serializers/decoders.py +82 -31
- sky/server/requests/serializers/encoders.py +140 -22
- sky/server/requests/threads.py +117 -0
- sky/server/rest.py +455 -0
- sky/server/server.py +1309 -285
- sky/server/state.py +20 -0
- sky/server/stream_utils.py +327 -61
- sky/server/uvicorn.py +217 -3
- sky/server/versions.py +270 -0
- sky/setup_files/MANIFEST.in +11 -1
- sky/setup_files/alembic.ini +160 -0
- sky/setup_files/dependencies.py +139 -31
- sky/setup_files/setup.py +44 -42
- sky/sky_logging.py +114 -7
- sky/skylet/attempt_skylet.py +106 -24
- sky/skylet/autostop_lib.py +129 -8
- sky/skylet/configs.py +29 -20
- sky/skylet/constants.py +216 -25
- sky/skylet/events.py +101 -21
- sky/skylet/job_lib.py +345 -164
- sky/skylet/log_lib.py +297 -18
- sky/skylet/log_lib.pyi +44 -1
- sky/skylet/providers/ibm/node_provider.py +12 -8
- sky/skylet/providers/ibm/vpc_provider.py +13 -12
- sky/skylet/ray_patches/__init__.py +17 -3
- sky/skylet/ray_patches/autoscaler.py.diff +18 -0
- sky/skylet/ray_patches/cli.py.diff +19 -0
- sky/skylet/ray_patches/command_runner.py.diff +17 -0
- sky/skylet/ray_patches/log_monitor.py.diff +20 -0
- sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
- sky/skylet/ray_patches/updater.py.diff +18 -0
- sky/skylet/ray_patches/worker.py.diff +41 -0
- sky/skylet/runtime_utils.py +21 -0
- sky/skylet/services.py +568 -0
- sky/skylet/skylet.py +72 -4
- sky/skylet/subprocess_daemon.py +104 -29
- sky/skypilot_config.py +506 -99
- sky/ssh_node_pools/__init__.py +1 -0
- sky/ssh_node_pools/core.py +135 -0
- sky/ssh_node_pools/server.py +233 -0
- sky/task.py +685 -163
- sky/templates/aws-ray.yml.j2 +11 -3
- sky/templates/azure-ray.yml.j2 +2 -1
- sky/templates/cudo-ray.yml.j2 +1 -0
- sky/templates/do-ray.yml.j2 +2 -1
- sky/templates/fluidstack-ray.yml.j2 +1 -0
- sky/templates/gcp-ray.yml.j2 +62 -1
- sky/templates/hyperbolic-ray.yml.j2 +68 -0
- sky/templates/ibm-ray.yml.j2 +2 -1
- sky/templates/jobs-controller.yaml.j2 +27 -24
- sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
- sky/templates/kubernetes-ray.yml.j2 +611 -50
- sky/templates/lambda-ray.yml.j2 +2 -1
- sky/templates/nebius-ray.yml.j2 +34 -12
- sky/templates/oci-ray.yml.j2 +1 -0
- sky/templates/paperspace-ray.yml.j2 +2 -1
- sky/templates/primeintellect-ray.yml.j2 +72 -0
- sky/templates/runpod-ray.yml.j2 +10 -1
- sky/templates/scp-ray.yml.j2 +4 -50
- sky/templates/seeweb-ray.yml.j2 +171 -0
- sky/templates/shadeform-ray.yml.j2 +73 -0
- sky/templates/sky-serve-controller.yaml.j2 +22 -2
- sky/templates/vast-ray.yml.j2 +1 -0
- sky/templates/vsphere-ray.yml.j2 +1 -0
- sky/templates/websocket_proxy.py +212 -37
- sky/usage/usage_lib.py +31 -15
- sky/users/__init__.py +0 -0
- sky/users/model.conf +15 -0
- sky/users/permission.py +397 -0
- sky/users/rbac.py +121 -0
- sky/users/server.py +720 -0
- sky/users/token_service.py +218 -0
- sky/utils/accelerator_registry.py +35 -5
- sky/utils/admin_policy_utils.py +84 -38
- sky/utils/annotations.py +38 -5
- sky/utils/asyncio_utils.py +78 -0
- sky/utils/atomic.py +1 -1
- sky/utils/auth_utils.py +153 -0
- sky/utils/benchmark_utils.py +60 -0
- sky/utils/cli_utils/status_utils.py +159 -86
- sky/utils/cluster_utils.py +31 -9
- sky/utils/command_runner.py +354 -68
- sky/utils/command_runner.pyi +93 -3
- sky/utils/common.py +35 -8
- sky/utils/common_utils.py +314 -91
- sky/utils/config_utils.py +74 -5
- sky/utils/context.py +403 -0
- sky/utils/context_utils.py +242 -0
- sky/utils/controller_utils.py +383 -89
- sky/utils/dag_utils.py +31 -12
- sky/utils/db/__init__.py +0 -0
- sky/utils/db/db_utils.py +485 -0
- sky/utils/db/kv_cache.py +149 -0
- sky/utils/db/migration_utils.py +137 -0
- sky/utils/directory_utils.py +12 -0
- sky/utils/env_options.py +13 -0
- sky/utils/git.py +567 -0
- sky/utils/git_clone.sh +460 -0
- sky/utils/infra_utils.py +195 -0
- sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
- sky/utils/kubernetes/config_map_utils.py +133 -0
- sky/utils/kubernetes/create_cluster.sh +15 -29
- sky/utils/kubernetes/delete_cluster.sh +10 -7
- sky/utils/kubernetes/deploy_ssh_node_pools.py +1177 -0
- sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
- sky/utils/kubernetes/generate_kind_config.py +6 -66
- sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
- sky/utils/kubernetes/gpu_labeler.py +18 -8
- sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
- sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
- sky/utils/kubernetes/kubernetes_deploy_utils.py +284 -114
- sky/utils/kubernetes/rsync_helper.sh +11 -3
- sky/utils/kubernetes/ssh-tunnel.sh +379 -0
- sky/utils/kubernetes/ssh_utils.py +221 -0
- sky/utils/kubernetes_enums.py +8 -15
- sky/utils/lock_events.py +94 -0
- sky/utils/locks.py +416 -0
- sky/utils/log_utils.py +82 -107
- sky/utils/perf_utils.py +22 -0
- sky/utils/resource_checker.py +298 -0
- sky/utils/resources_utils.py +249 -32
- sky/utils/rich_utils.py +217 -39
- sky/utils/schemas.py +955 -160
- sky/utils/serialize_utils.py +16 -0
- sky/utils/status_lib.py +10 -0
- sky/utils/subprocess_utils.py +29 -15
- sky/utils/tempstore.py +70 -0
- sky/utils/thread_utils.py +91 -0
- sky/utils/timeline.py +26 -53
- sky/utils/ux_utils.py +84 -15
- sky/utils/validator.py +11 -1
- sky/utils/volume.py +165 -0
- sky/utils/yaml_utils.py +111 -0
- sky/volumes/__init__.py +13 -0
- sky/volumes/client/__init__.py +0 -0
- sky/volumes/client/sdk.py +150 -0
- sky/volumes/server/__init__.py +0 -0
- sky/volumes/server/core.py +270 -0
- sky/volumes/server/server.py +124 -0
- sky/volumes/volume.py +215 -0
- sky/workspaces/__init__.py +0 -0
- sky/workspaces/core.py +655 -0
- sky/workspaces/server.py +101 -0
- sky/workspaces/utils.py +56 -0
- sky_templates/README.md +3 -0
- sky_templates/__init__.py +3 -0
- sky_templates/ray/__init__.py +0 -0
- sky_templates/ray/start_cluster +183 -0
- sky_templates/ray/stop_cluster +75 -0
- skypilot_nightly-1.0.0.dev20251203.dist-info/METADATA +676 -0
- skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
- {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +1 -1
- skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
- sky/benchmark/benchmark_state.py +0 -256
- sky/benchmark/benchmark_utils.py +0 -641
- sky/clouds/service_catalog/constants.py +0 -7
- sky/dashboard/out/_next/static/GWvVBSCS7FmUiVmjaL1a7/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/236-2db3ee3fba33dd9e.js +0 -6
- sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
- sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
- sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
- sky/dashboard/out/_next/static/chunks/845-9e60713e0c441abc.js +0 -1
- sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
- sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
- sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
- sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
- sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6ac338bc2239cb45.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1c519e1afc523dc9.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
- sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
- sky/jobs/dashboard/dashboard.py +0 -223
- sky/jobs/dashboard/static/favicon.ico +0 -0
- sky/jobs/dashboard/templates/index.html +0 -831
- sky/jobs/server/dashboard_utils.py +0 -69
- sky/skylet/providers/scp/__init__.py +0 -2
- sky/skylet/providers/scp/config.py +0 -149
- sky/skylet/providers/scp/node_provider.py +0 -578
- sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
- sky/utils/db_utils.py +0 -100
- sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
- sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
- skypilot_nightly-1.0.0.dev20250502.dist-info/METADATA +0 -361
- skypilot_nightly-1.0.0.dev20250502.dist-info/RECORD +0 -396
- skypilot_nightly-1.0.0.dev20250502.dist-info/top_level.txt +0 -1
- /sky/{clouds/service_catalog → catalog}/config.py +0 -0
- /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
- /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
- /sky/dashboard/out/_next/static/{GWvVBSCS7FmUiVmjaL1a7 → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
sky/client/sdk.py
CHANGED
|
@@ -10,75 +10,165 @@ Usage example:
|
|
|
10
10
|
statuses = sky.get(request_id)
|
|
11
11
|
|
|
12
12
|
"""
|
|
13
|
-
import
|
|
13
|
+
from http import cookiejar
|
|
14
14
|
import json
|
|
15
15
|
import logging
|
|
16
16
|
import os
|
|
17
|
-
import pathlib
|
|
18
17
|
import subprocess
|
|
19
18
|
import typing
|
|
20
|
-
from typing import Any, Dict, List, Optional, Tuple,
|
|
21
|
-
|
|
19
|
+
from typing import (Any, Dict, Iterator, List, Literal, Optional, Tuple,
|
|
20
|
+
TypeVar, Union)
|
|
21
|
+
from urllib import parse as urlparse
|
|
22
22
|
|
|
23
23
|
import click
|
|
24
24
|
import colorama
|
|
25
25
|
import filelock
|
|
26
26
|
|
|
27
27
|
from sky import admin_policy
|
|
28
|
-
from sky import backends
|
|
29
28
|
from sky import exceptions
|
|
30
29
|
from sky import sky_logging
|
|
31
30
|
from sky import skypilot_config
|
|
32
31
|
from sky.adaptors import common as adaptors_common
|
|
33
32
|
from sky.client import common as client_common
|
|
33
|
+
from sky.client import oauth as oauth_lib
|
|
34
|
+
from sky.jobs import scheduler
|
|
35
|
+
from sky.jobs import utils as managed_job_utils
|
|
36
|
+
from sky.schemas.api import responses
|
|
34
37
|
from sky.server import common as server_common
|
|
38
|
+
from sky.server import rest
|
|
39
|
+
from sky.server import versions
|
|
35
40
|
from sky.server.requests import payloads
|
|
41
|
+
from sky.server.requests import request_names
|
|
36
42
|
from sky.server.requests import requests as requests_lib
|
|
43
|
+
from sky.skylet import autostop_lib
|
|
37
44
|
from sky.skylet import constants
|
|
38
45
|
from sky.usage import usage_lib
|
|
46
|
+
from sky.utils import admin_policy_utils
|
|
39
47
|
from sky.utils import annotations
|
|
40
48
|
from sky.utils import cluster_utils
|
|
41
49
|
from sky.utils import common
|
|
42
50
|
from sky.utils import common_utils
|
|
51
|
+
from sky.utils import context as sky_context
|
|
43
52
|
from sky.utils import dag_utils
|
|
44
53
|
from sky.utils import env_options
|
|
54
|
+
from sky.utils import infra_utils
|
|
45
55
|
from sky.utils import rich_utils
|
|
46
56
|
from sky.utils import status_lib
|
|
47
57
|
from sky.utils import subprocess_utils
|
|
48
58
|
from sky.utils import ux_utils
|
|
59
|
+
from sky.utils import yaml_utils
|
|
60
|
+
from sky.utils.kubernetes import ssh_utils
|
|
49
61
|
|
|
50
62
|
if typing.TYPE_CHECKING:
|
|
63
|
+
import base64
|
|
64
|
+
import binascii
|
|
51
65
|
import io
|
|
66
|
+
import pathlib
|
|
67
|
+
import time
|
|
68
|
+
import webbrowser
|
|
52
69
|
|
|
53
70
|
import psutil
|
|
54
71
|
import requests
|
|
55
72
|
|
|
56
73
|
import sky
|
|
74
|
+
from sky import backends
|
|
75
|
+
from sky import catalog
|
|
76
|
+
from sky import models
|
|
77
|
+
from sky.provision.kubernetes import utils as kubernetes_utils
|
|
78
|
+
from sky.skylet import job_lib
|
|
57
79
|
else:
|
|
80
|
+
# only used in api_login()
|
|
81
|
+
base64 = adaptors_common.LazyImport('base64')
|
|
82
|
+
binascii = adaptors_common.LazyImport('binascii')
|
|
83
|
+
pathlib = adaptors_common.LazyImport('pathlib')
|
|
84
|
+
time = adaptors_common.LazyImport('time')
|
|
85
|
+
# only used in dashboard() and api_login()
|
|
86
|
+
webbrowser = adaptors_common.LazyImport('webbrowser')
|
|
87
|
+
# only used in api_stop()
|
|
58
88
|
psutil = adaptors_common.LazyImport('psutil')
|
|
59
|
-
requests = adaptors_common.LazyImport('requests')
|
|
60
89
|
|
|
61
90
|
logger = sky_logging.init_logger(__name__)
|
|
62
91
|
logging.getLogger('httpx').setLevel(logging.CRITICAL)
|
|
63
92
|
|
|
93
|
+
_LINE_PROCESSED_KEY = 'line_processed'
|
|
64
94
|
|
|
65
|
-
|
|
95
|
+
T = TypeVar('T')
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def reload_config() -> None:
|
|
99
|
+
"""Reloads the client-side config."""
|
|
100
|
+
skypilot_config.safe_reload_config()
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# The overloads are not comprehensive - e.g. get_result Literal[False] could be
|
|
104
|
+
# specified to return None. We can add more overloads if needed. To do that see
|
|
105
|
+
# https://github.com/python/mypy/issues/8634#issuecomment-609411104
|
|
106
|
+
@typing.overload
|
|
107
|
+
def stream_response(request_id: None,
|
|
66
108
|
response: 'requests.Response',
|
|
67
|
-
output_stream: Optional['io.TextIOBase'] = None
|
|
109
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
110
|
+
resumable: bool = False,
|
|
111
|
+
get_result: bool = True) -> None:
|
|
112
|
+
...
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@typing.overload
|
|
116
|
+
def stream_response(request_id: server_common.RequestId[T],
|
|
117
|
+
response: 'requests.Response',
|
|
118
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
119
|
+
resumable: bool = False,
|
|
120
|
+
get_result: Literal[True] = True) -> T:
|
|
121
|
+
...
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@typing.overload
|
|
125
|
+
def stream_response(request_id: server_common.RequestId[T],
|
|
126
|
+
response: 'requests.Response',
|
|
127
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
128
|
+
resumable: bool = False,
|
|
129
|
+
get_result: bool = True) -> Optional[T]:
|
|
130
|
+
...
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def stream_response(request_id: Optional[server_common.RequestId[T]],
|
|
134
|
+
response: 'requests.Response',
|
|
135
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
136
|
+
resumable: bool = False,
|
|
137
|
+
get_result: bool = True) -> Optional[T]:
|
|
68
138
|
"""Streams the response to the console.
|
|
69
139
|
|
|
70
140
|
Args:
|
|
71
|
-
request_id: The request ID.
|
|
141
|
+
request_id: The request ID of the request to stream. May be a full
|
|
142
|
+
request ID or a prefix.
|
|
143
|
+
If None, the latest request submitted to the API server is streamed.
|
|
144
|
+
Using None request_id is not recommended in multi-user environments.
|
|
72
145
|
response: The HTTP response.
|
|
73
146
|
output_stream: The output stream to write to. If None, print to the
|
|
74
147
|
console.
|
|
148
|
+
resumable: Whether the response is resumable on retry. If True, the
|
|
149
|
+
streaming will start from the previous failure point on retry.
|
|
150
|
+
get_result: Whether to get the result of the request. This will
|
|
151
|
+
typically be set to False for `--no-follow` flags as requests may
|
|
152
|
+
continue to run for long periods of time without further streaming.
|
|
75
153
|
"""
|
|
76
154
|
|
|
155
|
+
retry_context: Optional[rest.RetryContext] = None
|
|
156
|
+
if resumable:
|
|
157
|
+
retry_context = rest.get_retry_context()
|
|
77
158
|
try:
|
|
159
|
+
line_count = 0
|
|
78
160
|
for line in rich_utils.decode_rich_status(response):
|
|
79
161
|
if line is not None:
|
|
80
|
-
|
|
81
|
-
|
|
162
|
+
line_count += 1
|
|
163
|
+
if retry_context is None:
|
|
164
|
+
print(line, flush=True, end='', file=output_stream)
|
|
165
|
+
elif line_count > retry_context.line_processed:
|
|
166
|
+
print(line, flush=True, end='', file=output_stream)
|
|
167
|
+
retry_context.line_processed = line_count
|
|
168
|
+
if request_id is not None and get_result:
|
|
169
|
+
return get(request_id)
|
|
170
|
+
else:
|
|
171
|
+
return None
|
|
82
172
|
except Exception: # pylint: disable=broad-except
|
|
83
173
|
logger.debug(f'To stream request logs: sky api logs {request_id}')
|
|
84
174
|
raise
|
|
@@ -87,13 +177,18 @@ def stream_response(request_id: Optional[str],
|
|
|
87
177
|
@usage_lib.entrypoint
|
|
88
178
|
@server_common.check_server_healthy_or_start
|
|
89
179
|
@annotations.client_api
|
|
90
|
-
def check(
|
|
91
|
-
|
|
180
|
+
def check(
|
|
181
|
+
infra_list: Optional[Tuple[str, ...]],
|
|
182
|
+
verbose: bool,
|
|
183
|
+
workspace: Optional[str] = None
|
|
184
|
+
) -> server_common.RequestId[Dict[str, List[str]]]:
|
|
92
185
|
"""Checks the credentials to enable clouds.
|
|
93
186
|
|
|
94
187
|
Args:
|
|
95
|
-
|
|
188
|
+
infra_list: The infras to check; only the cloud part of each is used.
|
|
96
189
|
verbose: Whether to show verbose output.
|
|
190
|
+
workspace: The workspace to check. If None, all workspaces will be
|
|
191
|
+
checked.
|
|
97
192
|
|
|
98
193
|
Returns:
|
|
99
194
|
The request ID of the check request.
|
|
@@ -101,41 +196,69 @@ def check(clouds: Optional[Tuple[str]],
|
|
|
101
196
|
Request Returns:
|
|
102
197
|
None
|
|
103
198
|
"""
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
199
|
+
if infra_list is None:
|
|
200
|
+
clouds = None
|
|
201
|
+
else:
|
|
202
|
+
specified_clouds = []
|
|
203
|
+
for infra_str in infra_list:
|
|
204
|
+
infra = infra_utils.InfraInfo.from_str(infra_str)
|
|
205
|
+
if infra.cloud is None:
|
|
206
|
+
with ux_utils.print_exception_no_traceback():
|
|
207
|
+
raise ValueError(f'Invalid infra to check: {infra_str}')
|
|
208
|
+
if infra.region is not None or infra.zone is not None:
|
|
209
|
+
region_zone = infra_str.partition('/')[-1]
|
|
210
|
+
logger.warning(f'Infra {infra_str} is specified, but `check` '
|
|
211
|
+
f'only supports checking {infra.cloud}, '
|
|
212
|
+
f'ignoring {region_zone}')
|
|
213
|
+
specified_clouds.append(infra.cloud)
|
|
214
|
+
clouds = tuple(specified_clouds)
|
|
215
|
+
body = payloads.CheckBody(clouds=clouds,
|
|
216
|
+
verbose=verbose,
|
|
217
|
+
workspace=workspace)
|
|
218
|
+
response = server_common.make_authenticated_request(
|
|
219
|
+
'POST', '/check', json=json.loads(body.model_dump_json()))
|
|
108
220
|
return server_common.get_request_id(response)
|
|
109
221
|
|
|
110
222
|
|
|
111
223
|
@usage_lib.entrypoint
|
|
112
224
|
@server_common.check_server_healthy_or_start
|
|
113
225
|
@annotations.client_api
|
|
114
|
-
def enabled_clouds(
|
|
226
|
+
def enabled_clouds(workspace: Optional[str] = None,
|
|
227
|
+
expand: bool = False) -> server_common.RequestId[List[str]]:
|
|
115
228
|
"""Gets the enabled clouds.
|
|
116
229
|
|
|
230
|
+
Args:
|
|
231
|
+
workspace: The workspace to get the enabled clouds for. If None, the
|
|
232
|
+
active workspace will be used.
|
|
233
|
+
expand: Whether to expand Kubernetes and SSH to list of resource pools.
|
|
234
|
+
|
|
117
235
|
Returns:
|
|
118
236
|
The request ID of the enabled clouds request.
|
|
119
237
|
|
|
120
238
|
Request Returns:
|
|
121
239
|
A list of enabled clouds in string format.
|
|
122
240
|
"""
|
|
123
|
-
|
|
124
|
-
|
|
241
|
+
if workspace is None:
|
|
242
|
+
workspace = skypilot_config.get_active_workspace()
|
|
243
|
+
response = server_common.make_authenticated_request(
|
|
244
|
+
'GET', f'/enabled_clouds?workspace={workspace}&expand={expand}')
|
|
125
245
|
return server_common.get_request_id(response)
|
|
126
246
|
|
|
127
247
|
|
|
128
248
|
@usage_lib.entrypoint
|
|
129
249
|
@server_common.check_server_healthy_or_start
|
|
130
250
|
@annotations.client_api
|
|
131
|
-
def list_accelerators(
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
251
|
+
def list_accelerators(
|
|
252
|
+
gpus_only: bool = True,
|
|
253
|
+
name_filter: Optional[str] = None,
|
|
254
|
+
region_filter: Optional[str] = None,
|
|
255
|
+
quantity_filter: Optional[int] = None,
|
|
256
|
+
clouds: Optional[Union[List[str], str]] = None,
|
|
257
|
+
all_regions: bool = False,
|
|
258
|
+
require_price: bool = True,
|
|
259
|
+
case_sensitive: bool = True
|
|
260
|
+
) -> server_common.RequestId[Dict[str,
|
|
261
|
+
List['catalog.common.InstanceTypeInfo']]]:
|
|
139
262
|
"""Lists the names of all accelerators offered by Sky.
|
|
140
263
|
|
|
141
264
|
This will include all accelerators offered by Sky, including those
|
|
@@ -169,10 +292,8 @@ def list_accelerators(gpus_only: bool = True,
|
|
|
169
292
|
require_price=require_price,
|
|
170
293
|
case_sensitive=case_sensitive,
|
|
171
294
|
)
|
|
172
|
-
response =
|
|
173
|
-
|
|
174
|
-
json=json.loads(body.model_dump_json()),
|
|
175
|
-
cookies=server_common.get_api_cookie_jar())
|
|
295
|
+
response = server_common.make_authenticated_request(
|
|
296
|
+
'POST', '/list_accelerators', json=json.loads(body.model_dump_json()))
|
|
176
297
|
return server_common.get_request_id(response)
|
|
177
298
|
|
|
178
299
|
|
|
@@ -180,12 +301,12 @@ def list_accelerators(gpus_only: bool = True,
|
|
|
180
301
|
@server_common.check_server_healthy_or_start
|
|
181
302
|
@annotations.client_api
|
|
182
303
|
def list_accelerator_counts(
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
304
|
+
gpus_only: bool = True,
|
|
305
|
+
name_filter: Optional[str] = None,
|
|
306
|
+
region_filter: Optional[str] = None,
|
|
307
|
+
quantity_filter: Optional[int] = None,
|
|
308
|
+
clouds: Optional[Union[List[str], str]] = None
|
|
309
|
+
) -> server_common.RequestId[Dict[str, List[float]]]:
|
|
189
310
|
"""Lists all accelerators offered by Sky and available counts.
|
|
190
311
|
|
|
191
312
|
Args:
|
|
@@ -203,17 +324,17 @@ def list_accelerator_counts(
|
|
|
203
324
|
accelerator names mapped to a list of available counts. See usage
|
|
204
325
|
in cli.py.
|
|
205
326
|
"""
|
|
206
|
-
body = payloads.
|
|
327
|
+
body = payloads.ListAcceleratorCountsBody(
|
|
207
328
|
gpus_only=gpus_only,
|
|
208
329
|
name_filter=name_filter,
|
|
209
330
|
region_filter=region_filter,
|
|
210
331
|
quantity_filter=quantity_filter,
|
|
211
332
|
clouds=clouds,
|
|
212
333
|
)
|
|
213
|
-
response =
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
334
|
+
response = server_common.make_authenticated_request(
|
|
335
|
+
'POST',
|
|
336
|
+
'/list_accelerator_counts',
|
|
337
|
+
json=json.loads(body.model_dump_json()))
|
|
217
338
|
return server_common.get_request_id(response)
|
|
218
339
|
|
|
219
340
|
|
|
@@ -224,7 +345,7 @@ def optimize(
|
|
|
224
345
|
dag: 'sky.Dag',
|
|
225
346
|
minimize: common.OptimizeTarget = common.OptimizeTarget.COST,
|
|
226
347
|
admin_policy_request_options: Optional[admin_policy.RequestOptions] = None
|
|
227
|
-
) -> server_common.RequestId:
|
|
348
|
+
) -> server_common.RequestId['sky.Dag']:
|
|
228
349
|
"""Finds the best execution plan for the given DAG.
|
|
229
350
|
|
|
230
351
|
Args:
|
|
@@ -250,12 +371,27 @@ def optimize(
|
|
|
250
371
|
body = payloads.OptimizeBody(dag=dag_str,
|
|
251
372
|
minimize=minimize,
|
|
252
373
|
request_options=admin_policy_request_options)
|
|
253
|
-
response =
|
|
254
|
-
|
|
255
|
-
cookies=server_common.get_api_cookie_jar())
|
|
374
|
+
response = server_common.make_authenticated_request(
|
|
375
|
+
'POST', '/optimize', json=json.loads(body.model_dump_json()))
|
|
256
376
|
return server_common.get_request_id(response)
|
|
257
377
|
|
|
258
378
|
|
|
379
|
+
def workspaces() -> server_common.RequestId[Dict[str, Any]]:
|
|
380
|
+
"""Gets the workspaces."""
|
|
381
|
+
response = server_common.make_authenticated_request('GET', '/workspaces')
|
|
382
|
+
return server_common.get_request_id(response)
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def _raise_exception_object_on_client(e: BaseException) -> None:
|
|
386
|
+
"""Raise the exception object on the client."""
|
|
387
|
+
if env_options.Options.SHOW_DEBUG_INFO.get():
|
|
388
|
+
stacktrace = getattr(e, 'stacktrace', str(e))
|
|
389
|
+
logger.error('=== Traceback on SkyPilot API Server ===\n'
|
|
390
|
+
f'{stacktrace}')
|
|
391
|
+
with ux_utils.print_exception_no_traceback():
|
|
392
|
+
raise e
|
|
393
|
+
|
|
394
|
+
|
|
259
395
|
@usage_lib.entrypoint
|
|
260
396
|
@server_common.check_server_healthy_or_start
|
|
261
397
|
@annotations.client_api
|
|
@@ -279,29 +415,35 @@ def validate(
|
|
|
279
415
|
validation. This is only required when an admin policy is in use,
|
|
280
416
|
see: https://docs.skypilot.co/en/latest/cloud-setup/policy.html
|
|
281
417
|
"""
|
|
418
|
+
remote_api_version = versions.get_remote_api_version()
|
|
419
|
+
# TODO(kevin): remove this in v0.13.0
|
|
420
|
+
omit_user_specified_yaml = (remote_api_version is None or
|
|
421
|
+
remote_api_version < 15)
|
|
282
422
|
for task in dag.tasks:
|
|
423
|
+
if omit_user_specified_yaml:
|
|
424
|
+
# pylint: disable=protected-access
|
|
425
|
+
task._user_specified_yaml = None
|
|
283
426
|
task.expand_and_validate_workdir()
|
|
284
427
|
if not workdir_only:
|
|
285
428
|
task.expand_and_validate_file_mounts()
|
|
286
429
|
dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
|
|
287
430
|
body = payloads.ValidateBody(dag=dag_str,
|
|
288
431
|
request_options=admin_policy_request_options)
|
|
289
|
-
response =
|
|
290
|
-
|
|
291
|
-
cookies=server_common.get_api_cookie_jar())
|
|
432
|
+
response = server_common.make_authenticated_request(
|
|
433
|
+
'POST', '/validate', json=json.loads(body.model_dump_json()))
|
|
292
434
|
if response.status_code == 400:
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
response.json().get('detail'))
|
|
435
|
+
_raise_exception_object_on_client(
|
|
436
|
+
exceptions.deserialize_exception(response.json().get('detail')))
|
|
296
437
|
|
|
297
438
|
|
|
298
439
|
@usage_lib.entrypoint
|
|
299
440
|
@server_common.check_server_healthy_or_start
|
|
300
441
|
@annotations.client_api
|
|
301
|
-
def dashboard() -> None:
|
|
442
|
+
def dashboard(starting_page: Optional[str] = None) -> None:
|
|
302
443
|
"""Starts the dashboard for SkyPilot."""
|
|
303
444
|
api_server_url = server_common.get_server_url()
|
|
304
|
-
url = server_common.get_dashboard_url(api_server_url
|
|
445
|
+
url = server_common.get_dashboard_url(api_server_url,
|
|
446
|
+
starting_page=starting_page)
|
|
305
447
|
logger.info(f'Opening dashboard in browser: {url}')
|
|
306
448
|
webbrowser.open(url)
|
|
307
449
|
|
|
@@ -309,14 +451,16 @@ def dashboard() -> None:
|
|
|
309
451
|
@usage_lib.entrypoint
|
|
310
452
|
@server_common.check_server_healthy_or_start
|
|
311
453
|
@annotations.client_api
|
|
454
|
+
@sky_context.contextual
|
|
312
455
|
def launch(
|
|
313
456
|
task: Union['sky.Task', 'sky.Dag'],
|
|
314
457
|
cluster_name: Optional[str] = None,
|
|
315
458
|
retry_until_up: bool = False,
|
|
316
459
|
idle_minutes_to_autostop: Optional[int] = None,
|
|
460
|
+
wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
|
|
317
461
|
dryrun: bool = False,
|
|
318
462
|
down: bool = False, # pylint: disable=redefined-outer-name
|
|
319
|
-
backend: Optional[backends.Backend] = None,
|
|
463
|
+
backend: Optional['backends.Backend'] = None,
|
|
320
464
|
optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
|
|
321
465
|
no_setup: bool = False,
|
|
322
466
|
clone_disk_from: Optional[str] = None,
|
|
@@ -327,7 +471,8 @@ def launch(
|
|
|
327
471
|
_is_launched_by_jobs_controller: bool = False,
|
|
328
472
|
_is_launched_by_sky_serve_controller: bool = False,
|
|
329
473
|
_disable_controller_check: bool = False,
|
|
330
|
-
) -> server_common.RequestId
|
|
474
|
+
) -> server_common.RequestId[Tuple[Optional[int],
|
|
475
|
+
Optional['backends.ResourceHandle']]]:
|
|
331
476
|
"""Launches a cluster or task.
|
|
332
477
|
|
|
333
478
|
The task's setup and run commands are executed under the task's workdir
|
|
@@ -344,7 +489,7 @@ def launch(
|
|
|
344
489
|
import sky
|
|
345
490
|
task = sky.Task(run='echo hello SkyPilot')
|
|
346
491
|
task.set_resources(
|
|
347
|
-
sky.Resources(
|
|
492
|
+
sky.Resources(infra='aws', accelerators='V100:4'))
|
|
348
493
|
sky.launch(task, cluster_name='my-cluster')
|
|
349
494
|
|
|
350
495
|
|
|
@@ -355,18 +500,31 @@ def launch(
|
|
|
355
500
|
retry_until_up: whether to retry launching the cluster until it is
|
|
356
501
|
up.
|
|
357
502
|
idle_minutes_to_autostop: automatically stop the cluster after this
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
503
|
+
many minutes of idleness, i.e., no running or pending jobs in the
|
|
504
|
+
cluster's job queue. Idleness gets reset whenever setting-up/
|
|
505
|
+
running/pending jobs are found in the job queue. Setting this
|
|
506
|
+
flag is equivalent to running
|
|
507
|
+
``sky.launch(...)`` and then
|
|
508
|
+
``sky.autostop(idle_minutes=<minutes>)``. If set, the autostop
|
|
509
|
+
config specified in the task's resources will be overridden by
|
|
510
|
+
this parameter.
|
|
511
|
+
wait_for: determines the condition for resetting the idleness timer.
|
|
512
|
+
This option works in conjunction with ``idle_minutes_to_autostop``.
|
|
513
|
+
Choices:
|
|
514
|
+
|
|
515
|
+
1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
|
|
516
|
+
connections to finish.
|
|
517
|
+
2. "jobs" - Only wait for in-progress jobs.
|
|
518
|
+
3. "none" - Wait for nothing; autostop right after
|
|
519
|
+
``idle_minutes_to_autostop``.
|
|
364
520
|
dryrun: if True, do not actually launch the cluster.
|
|
365
521
|
down: Tear down the cluster after all jobs finish (successfully or
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
522
|
+
abnormally). If --idle-minutes-to-autostop is also set, the
|
|
523
|
+
cluster will be torn down after the specified idle time.
|
|
524
|
+
Note that if errors occur during provisioning/data syncing/setting
|
|
525
|
+
up, the cluster will not be torn down for debugging purposes. If
|
|
526
|
+
set, the autostop config specified in the task's resources will be
|
|
527
|
+
overridden by this parameter.
|
|
370
528
|
backend: backend to use. If None, use the default backend
|
|
371
529
|
(CloudVMRayBackend).
|
|
372
530
|
optimize_target: target to optimize for. Choices: OptimizeTarget.COST,
|
|
@@ -422,35 +580,115 @@ def launch(
|
|
|
422
580
|
raise NotImplementedError('clone_disk_from is not implemented yet. '
|
|
423
581
|
'Please contact the SkyPilot team if you '
|
|
424
582
|
'need this feature at slack.skypilot.co.')
|
|
583
|
+
|
|
584
|
+
remote_api_version = versions.get_remote_api_version()
|
|
585
|
+
if wait_for is not None and (remote_api_version is None or
|
|
586
|
+
remote_api_version < 13):
|
|
587
|
+
logger.warning('wait_for is not supported in your API server. '
|
|
588
|
+
'Please upgrade to a newer API server to use it.')
|
|
589
|
+
|
|
425
590
|
dag = dag_utils.convert_entrypoint_to_dag(task)
|
|
591
|
+
# Override the autostop config from command line flags to task YAML.
|
|
592
|
+
for task in dag.tasks:
|
|
593
|
+
for resource in task.resources:
|
|
594
|
+
if remote_api_version is None or remote_api_version < 13:
|
|
595
|
+
# An older server would not recognize the wait_for field
|
|
596
|
+
# in the schema, so we need to omit it.
|
|
597
|
+
resource.override_autostop_config(
|
|
598
|
+
down=down, idle_minutes=idle_minutes_to_autostop)
|
|
599
|
+
else:
|
|
600
|
+
resource.override_autostop_config(
|
|
601
|
+
down=down,
|
|
602
|
+
idle_minutes=idle_minutes_to_autostop,
|
|
603
|
+
wait_for=wait_for)
|
|
604
|
+
if resource.autostop_config is not None:
|
|
605
|
+
# For backward-compatibility, get the final autostop config for
|
|
606
|
+
# admin policy.
|
|
607
|
+
# TODO(aylei): remove this after 0.12.0
|
|
608
|
+
down = resource.autostop_config.down
|
|
609
|
+
idle_minutes_to_autostop = resource.autostop_config.idle_minutes
|
|
610
|
+
|
|
426
611
|
request_options = admin_policy.RequestOptions(
|
|
427
612
|
cluster_name=cluster_name,
|
|
428
613
|
idle_minutes_to_autostop=idle_minutes_to_autostop,
|
|
429
614
|
down=down,
|
|
430
615
|
dryrun=dryrun)
|
|
616
|
+
with admin_policy_utils.apply_and_use_config_in_current_request(
|
|
617
|
+
dag,
|
|
618
|
+
request_name=request_names.AdminPolicyRequestName.CLUSTER_LAUNCH,
|
|
619
|
+
request_options=request_options,
|
|
620
|
+
at_client_side=True) as dag:
|
|
621
|
+
return _launch(
|
|
622
|
+
dag,
|
|
623
|
+
cluster_name,
|
|
624
|
+
request_options,
|
|
625
|
+
retry_until_up,
|
|
626
|
+
idle_minutes_to_autostop,
|
|
627
|
+
dryrun,
|
|
628
|
+
down,
|
|
629
|
+
backend,
|
|
630
|
+
optimize_target,
|
|
631
|
+
no_setup,
|
|
632
|
+
clone_disk_from,
|
|
633
|
+
fast,
|
|
634
|
+
_need_confirmation,
|
|
635
|
+
_is_launched_by_jobs_controller,
|
|
636
|
+
_is_launched_by_sky_serve_controller,
|
|
637
|
+
_disable_controller_check,
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def _launch(
|
|
642
|
+
dag: 'sky.Dag',
|
|
643
|
+
cluster_name: str,
|
|
644
|
+
request_options: admin_policy.RequestOptions,
|
|
645
|
+
retry_until_up: bool = False,
|
|
646
|
+
idle_minutes_to_autostop: Optional[int] = None,
|
|
647
|
+
dryrun: bool = False,
|
|
648
|
+
down: bool = False, # pylint: disable=redefined-outer-name
|
|
649
|
+
backend: Optional['backends.Backend'] = None,
|
|
650
|
+
optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
|
|
651
|
+
no_setup: bool = False,
|
|
652
|
+
clone_disk_from: Optional[str] = None,
|
|
653
|
+
fast: bool = False,
|
|
654
|
+
# Internal only:
|
|
655
|
+
# pylint: disable=invalid-name
|
|
656
|
+
_need_confirmation: bool = False,
|
|
657
|
+
_is_launched_by_jobs_controller: bool = False,
|
|
658
|
+
_is_launched_by_sky_serve_controller: bool = False,
|
|
659
|
+
_disable_controller_check: bool = False,
|
|
660
|
+
) -> server_common.RequestId[Tuple[Optional[int],
|
|
661
|
+
Optional['backends.ResourceHandle']]]:
|
|
662
|
+
"""Auxiliary function for launch(), refer to launch() for details."""
|
|
663
|
+
|
|
431
664
|
validate(dag, admin_policy_request_options=request_options)
|
|
665
|
+
# The flags have been applied to the task YAML and the backward
|
|
666
|
+
# compatibility of admin policy has been handled. We should no longer use
|
|
667
|
+
# these flags.
|
|
668
|
+
del down, idle_minutes_to_autostop
|
|
432
669
|
|
|
433
670
|
confirm_shown = False
|
|
434
671
|
if _need_confirmation:
|
|
435
672
|
cluster_status = None
|
|
436
673
|
# TODO(SKY-998): we should reduce RTTs before launching the cluster.
|
|
437
|
-
|
|
438
|
-
clusters = get(
|
|
674
|
+
status_request_id = status([cluster_name], all_users=True)
|
|
675
|
+
clusters = get(status_request_id)
|
|
439
676
|
cluster_user_hash = common_utils.get_user_hash()
|
|
440
677
|
cluster_user_hash_str = ''
|
|
441
|
-
|
|
678
|
+
current_user = common_utils.get_current_user_name()
|
|
679
|
+
cluster_user_name = current_user
|
|
442
680
|
if not clusters:
|
|
443
681
|
# Show the optimize log before the prompt if the cluster does not
|
|
444
682
|
# exist.
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
stream_and_get(
|
|
683
|
+
optimize_request_id = optimize(
|
|
684
|
+
dag, admin_policy_request_options=request_options)
|
|
685
|
+
stream_and_get(optimize_request_id)
|
|
448
686
|
else:
|
|
449
687
|
cluster_record = clusters[0]
|
|
450
688
|
cluster_status = cluster_record['status']
|
|
451
689
|
cluster_user_hash = cluster_record['user_hash']
|
|
452
690
|
cluster_user_name = cluster_record['user_name']
|
|
453
|
-
if cluster_user_name ==
|
|
691
|
+
if cluster_user_name == current_user:
|
|
454
692
|
# Only show the hash if the username is the same as the local
|
|
455
693
|
# username, to avoid confusion.
|
|
456
694
|
cluster_user_hash_str = f' (hash: {cluster_user_hash})'
|
|
@@ -492,9 +730,7 @@ def launch(
|
|
|
492
730
|
task=dag_str,
|
|
493
731
|
cluster_name=cluster_name,
|
|
494
732
|
retry_until_up=retry_until_up,
|
|
495
|
-
idle_minutes_to_autostop=idle_minutes_to_autostop,
|
|
496
733
|
dryrun=dryrun,
|
|
497
|
-
down=down,
|
|
498
734
|
backend=backend.NAME if backend else None,
|
|
499
735
|
optimize_target=optimize_target,
|
|
500
736
|
no_setup=no_setup,
|
|
@@ -507,12 +743,8 @@ def launch(
|
|
|
507
743
|
_is_launched_by_sky_serve_controller),
|
|
508
744
|
disable_controller_check=_disable_controller_check,
|
|
509
745
|
)
|
|
510
|
-
response =
|
|
511
|
-
|
|
512
|
-
json=json.loads(body.model_dump_json()),
|
|
513
|
-
timeout=5,
|
|
514
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
515
|
-
)
|
|
746
|
+
response = server_common.make_authenticated_request(
|
|
747
|
+
'POST', '/launch', json=json.loads(body.model_dump_json()), timeout=5)
|
|
516
748
|
return server_common.get_request_id(response)
|
|
517
749
|
|
|
518
750
|
|
|
@@ -524,8 +756,9 @@ def exec( # pylint: disable=redefined-builtin
|
|
|
524
756
|
cluster_name: Optional[str] = None,
|
|
525
757
|
dryrun: bool = False,
|
|
526
758
|
down: bool = False, # pylint: disable=redefined-outer-name
|
|
527
|
-
backend: Optional[backends.Backend] = None,
|
|
528
|
-
) -> server_common.RequestId
|
|
759
|
+
backend: Optional['backends.Backend'] = None,
|
|
760
|
+
) -> server_common.RequestId[Tuple[Optional[int],
|
|
761
|
+
Optional['backends.ResourceHandle']]]:
|
|
529
762
|
"""Executes a task on an existing cluster.
|
|
530
763
|
|
|
531
764
|
This function performs two actions:
|
|
@@ -591,23 +824,49 @@ def exec( # pylint: disable=redefined-builtin
|
|
|
591
824
|
backend=backend.NAME if backend else None,
|
|
592
825
|
)
|
|
593
826
|
|
|
594
|
-
response =
|
|
595
|
-
|
|
596
|
-
json=json.loads(body.model_dump_json()),
|
|
597
|
-
timeout=5,
|
|
598
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
599
|
-
)
|
|
827
|
+
response = server_common.make_authenticated_request(
|
|
828
|
+
'POST', '/exec', json=json.loads(body.model_dump_json()), timeout=5)
|
|
600
829
|
return server_common.get_request_id(response)
|
|
601
830
|
|
|
602
831
|
|
|
603
|
-
@
|
|
604
|
-
|
|
605
|
-
|
|
832
|
+
@typing.overload
|
|
833
|
+
def tail_logs(
|
|
834
|
+
cluster_name: str,
|
|
835
|
+
job_id: Optional[int],
|
|
836
|
+
follow: bool,
|
|
837
|
+
tail: int = 0,
|
|
838
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
839
|
+
*, # keyword only separator
|
|
840
|
+
preload_content: Literal[True] = True) -> int:
|
|
841
|
+
...
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
@typing.overload
|
|
606
845
|
def tail_logs(cluster_name: str,
|
|
607
846
|
job_id: Optional[int],
|
|
608
847
|
follow: bool,
|
|
609
848
|
tail: int = 0,
|
|
610
|
-
output_stream:
|
|
849
|
+
output_stream: None = None,
|
|
850
|
+
*,
|
|
851
|
+
preload_content: Literal[False]) -> Iterator[Optional[str]]:
|
|
852
|
+
...
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
# TODO(aylei): retrying a logs request produces duplicated log entries.
|
|
856
|
+
# We should fix this.
|
|
857
|
+
@usage_lib.entrypoint
|
|
858
|
+
@server_common.check_server_healthy_or_start
|
|
859
|
+
@annotations.client_api
|
|
860
|
+
@rest.retry_transient_errors()
|
|
861
|
+
def tail_logs(
|
|
862
|
+
cluster_name: str,
|
|
863
|
+
job_id: Optional[int],
|
|
864
|
+
follow: bool,
|
|
865
|
+
tail: int = 0,
|
|
866
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
867
|
+
*, # keyword only separator
|
|
868
|
+
preload_content: bool = True
|
|
869
|
+
) -> Union[int, Iterator[Optional[str]]]:
|
|
611
870
|
"""Tails the logs of a job.
|
|
612
871
|
|
|
613
872
|
Args:
|
|
@@ -617,12 +876,21 @@ def tail_logs(cluster_name: str,
|
|
|
617
876
|
immediately.
|
|
618
877
|
tail: if > 0, tail the last N lines of the logs.
|
|
619
878
|
output_stream: the stream to write the logs to. If None, print to the
|
|
620
|
-
console.
|
|
879
|
+
console. Cannot be used with preload_content=False.
|
|
880
|
+
preload_content: if False, returns an Iterator[str | None] containing
|
|
881
|
+
the logs without the function blocking on the retrieval of entire
|
|
882
|
+
log. Iterator returns None when the log has been completely
|
|
883
|
+
streamed. Default True. Cannot be used with output_stream.
|
|
621
884
|
|
|
622
885
|
Returns:
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
886
|
+
If preload_content is True:
|
|
887
|
+
Exit code based on success or failure of the job. 0 if success,
|
|
888
|
+
100 if the job failed. See exceptions.JobExitCode for possible exit
|
|
889
|
+
codes.
|
|
890
|
+
If preload_content is False:
|
|
891
|
+
Iterator[str | None] containing the logs without the function
|
|
892
|
+
blocking on the retrieval of entire log. Iterator returns None
|
|
893
|
+
when the log has been completely streamed.
|
|
626
894
|
|
|
627
895
|
Request Raises:
|
|
628
896
|
ValueError: if arguments are invalid or the cluster is not supported.
|
|
@@ -635,21 +903,110 @@ def tail_logs(cluster_name: str,
|
|
|
635
903
|
sky.exceptions.CloudUserIdentityError: if we fail to get the current
|
|
636
904
|
user identity.
|
|
637
905
|
"""
|
|
906
|
+
if output_stream is not None and not preload_content:
|
|
907
|
+
raise ValueError(
|
|
908
|
+
'output_stream cannot be specified when preload_content is False')
|
|
909
|
+
|
|
638
910
|
body = payloads.ClusterJobBody(
|
|
639
911
|
cluster_name=cluster_name,
|
|
640
912
|
job_id=job_id,
|
|
641
913
|
follow=follow,
|
|
642
914
|
tail=tail,
|
|
643
915
|
)
|
|
644
|
-
response =
|
|
645
|
-
|
|
916
|
+
response = server_common.make_authenticated_request(
|
|
917
|
+
'POST',
|
|
918
|
+
'/logs',
|
|
646
919
|
json=json.loads(body.model_dump_json()),
|
|
647
920
|
stream=True,
|
|
648
921
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
649
|
-
None)
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
922
|
+
None))
|
|
923
|
+
request_id: server_common.RequestId[int] = server_common.get_request_id(
|
|
924
|
+
response)
|
|
925
|
+
if preload_content:
|
|
926
|
+
# Log request is idempotent when tail is 0, thus can resume previous
|
|
927
|
+
# streaming point on retry.
|
|
928
|
+
return stream_response(request_id=request_id,
|
|
929
|
+
response=response,
|
|
930
|
+
output_stream=output_stream,
|
|
931
|
+
resumable=(tail == 0))
|
|
932
|
+
else:
|
|
933
|
+
return rich_utils.decode_rich_status(response)
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
@usage_lib.entrypoint
|
|
937
|
+
@server_common.check_server_healthy_or_start
|
|
938
|
+
@versions.minimal_api_version(17)
|
|
939
|
+
@annotations.client_api
|
|
940
|
+
@rest.retry_transient_errors()
|
|
941
|
+
def tail_provision_logs(cluster_name: str,
|
|
942
|
+
worker: Optional[int] = None,
|
|
943
|
+
follow: bool = True,
|
|
944
|
+
tail: int = 0,
|
|
945
|
+
output_stream: Optional['io.TextIOBase'] = None) -> int:
|
|
946
|
+
"""Tails the provisioning logs (provision.log) for a cluster.
|
|
947
|
+
|
|
948
|
+
Args:
|
|
949
|
+
cluster_name: name of the cluster.
|
|
950
|
+
worker: worker id in multi-node cluster.
|
|
951
|
+
If None, stream the logs of the head node.
|
|
952
|
+
follow: follow the logs.
|
|
953
|
+
tail: lines from end to tail.
|
|
954
|
+
output_stream: optional stream to write logs.
|
|
955
|
+
Returns:
|
|
956
|
+
Exit code 0 on streaming success; raises on HTTP error.
|
|
957
|
+
"""
|
|
958
|
+
body = payloads.ProvisionLogsBody(cluster_name=cluster_name)
|
|
959
|
+
|
|
960
|
+
if worker is not None:
|
|
961
|
+
remote_api_version = versions.get_remote_api_version()
|
|
962
|
+
if remote_api_version is not None and remote_api_version >= 21:
|
|
963
|
+
if worker < 1:
|
|
964
|
+
raise ValueError('Worker must be a positive integer.')
|
|
965
|
+
body.worker = worker
|
|
966
|
+
else:
|
|
967
|
+
raise exceptions.APINotSupportedError(
|
|
968
|
+
'Worker node provision logs are not supported in your API '
|
|
969
|
+
'server. Please upgrade to a newer API server to use it.')
|
|
970
|
+
params = {
|
|
971
|
+
'follow': str(follow).lower(),
|
|
972
|
+
'tail': tail,
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
response = server_common.make_authenticated_request(
|
|
976
|
+
'POST',
|
|
977
|
+
'/provision_logs',
|
|
978
|
+
json=json.loads(body.model_dump_json()),
|
|
979
|
+
params=params,
|
|
980
|
+
stream=True,
|
|
981
|
+
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
982
|
+
None))
|
|
983
|
+
# Check for HTTP errors before streaming the response
|
|
984
|
+
if response.status_code != 200:
|
|
985
|
+
with ux_utils.print_exception_no_traceback():
|
|
986
|
+
raise exceptions.CommandError(response.status_code,
|
|
987
|
+
'tail_provision_logs',
|
|
988
|
+
'Failed to stream provision logs',
|
|
989
|
+
response.text)
|
|
990
|
+
|
|
991
|
+
# Log request is idempotent when tail is 0, thus can resume previous
|
|
992
|
+
# streaming point on retry.
|
|
993
|
+
# request_id=None here because /provision_logs does not create an async
|
|
994
|
+
# request. Instead, it streams a plain file from the server. This does NOT
|
|
995
|
+
# violate the stream_response doc warning about None in multi-user
|
|
996
|
+
# environments: we are not asking stream_response to select "the latest
|
|
997
|
+
# request". We already have the HTTP response to stream; request_id=None
|
|
998
|
+
# merely disables the follow-up GET. It is also necessary for --no-follow
|
|
999
|
+
# to return cleanly after printing the tailed lines. If we provided a
|
|
1000
|
+
# non-None request_id here, the get(request_id) in stream_response()
|
|
1001
|
+
# would fail since /provision_logs does not create a request record.
|
|
1002
|
+
# By virtue of this, we set get_result to False to block get() from
|
|
1003
|
+
# running.
|
|
1004
|
+
stream_response(request_id=None,
|
|
1005
|
+
response=response,
|
|
1006
|
+
output_stream=output_stream,
|
|
1007
|
+
resumable=(tail == 0),
|
|
1008
|
+
get_result=False)
|
|
1009
|
+
return 0
|
|
653
1010
|
|
|
654
1011
|
|
|
655
1012
|
@usage_lib.entrypoint
|
|
@@ -683,11 +1040,11 @@ def download_logs(cluster_name: str,
|
|
|
683
1040
|
cluster_name=cluster_name,
|
|
684
1041
|
job_ids=job_ids,
|
|
685
1042
|
)
|
|
686
|
-
response =
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
1043
|
+
response = server_common.make_authenticated_request(
|
|
1044
|
+
'POST', '/download_logs', json=json.loads(body.model_dump_json()))
|
|
1045
|
+
request_id: server_common.RequestId[Dict[
|
|
1046
|
+
str, str]] = server_common.get_request_id(response)
|
|
1047
|
+
job_id_remote_path_dict = stream_and_get(request_id)
|
|
691
1048
|
remote2local_path_dict = client_common.download_logs_from_api_server(
|
|
692
1049
|
job_id_remote_path_dict.values())
|
|
693
1050
|
return {
|
|
@@ -702,10 +1059,11 @@ def download_logs(cluster_name: str,
|
|
|
702
1059
|
def start(
|
|
703
1060
|
cluster_name: str,
|
|
704
1061
|
idle_minutes_to_autostop: Optional[int] = None,
|
|
1062
|
+
wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
|
|
705
1063
|
retry_until_up: bool = False,
|
|
706
1064
|
down: bool = False, # pylint: disable=redefined-outer-name
|
|
707
1065
|
force: bool = False,
|
|
708
|
-
) -> server_common.RequestId:
|
|
1066
|
+
) -> server_common.RequestId['backends.CloudVmRayResourceHandle']:
|
|
709
1067
|
"""Restart a cluster.
|
|
710
1068
|
|
|
711
1069
|
If a cluster is previously stopped (status is STOPPED) or failed in
|
|
@@ -728,6 +1086,15 @@ def start(
|
|
|
728
1086
|
flag is equivalent to running ``sky.launch()`` and then
|
|
729
1087
|
``sky.autostop(idle_minutes=<minutes>)``. If not set, the
|
|
730
1088
|
cluster will not be autostopped.
|
|
1089
|
+
wait_for: determines the condition for resetting the idleness timer.
|
|
1090
|
+
This option works in conjunction with ``idle_minutes_to_autostop``.
|
|
1091
|
+
Choices:
|
|
1092
|
+
|
|
1093
|
+
1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
|
|
1094
|
+
connections to finish.
|
|
1095
|
+
2. "jobs" - Only wait for in-progress jobs.
|
|
1096
|
+
3. "none" - Wait for nothing; autostop right after
|
|
1097
|
+
``idle_minutes_to_autostop``.
|
|
731
1098
|
retry_until_up: whether to retry launching the cluster until it is
|
|
732
1099
|
up.
|
|
733
1100
|
down: Autodown the cluster: tear down the cluster after specified
|
|
@@ -756,26 +1123,30 @@ def start(
|
|
|
756
1123
|
sky.exceptions.ClusterOwnerIdentitiesMismatchError: if the cluster to
|
|
757
1124
|
restart was launched by a different user.
|
|
758
1125
|
"""
|
|
1126
|
+
remote_api_version = versions.get_remote_api_version()
|
|
1127
|
+
if wait_for is not None and (remote_api_version is None or
|
|
1128
|
+
remote_api_version < 13):
|
|
1129
|
+
logger.warning('wait_for is not supported in your API server. '
|
|
1130
|
+
'Please upgrade to a newer API server to use it.')
|
|
1131
|
+
|
|
759
1132
|
body = payloads.StartBody(
|
|
760
1133
|
cluster_name=cluster_name,
|
|
761
1134
|
idle_minutes_to_autostop=idle_minutes_to_autostop,
|
|
1135
|
+
wait_for=wait_for,
|
|
762
1136
|
retry_until_up=retry_until_up,
|
|
763
1137
|
down=down,
|
|
764
1138
|
force=force,
|
|
765
1139
|
)
|
|
766
|
-
response =
|
|
767
|
-
|
|
768
|
-
json=json.loads(body.model_dump_json()),
|
|
769
|
-
timeout=5,
|
|
770
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
771
|
-
)
|
|
1140
|
+
response = server_common.make_authenticated_request(
|
|
1141
|
+
'POST', '/start', json=json.loads(body.model_dump_json()), timeout=5)
|
|
772
1142
|
return server_common.get_request_id(response)
|
|
773
1143
|
|
|
774
1144
|
|
|
775
1145
|
@usage_lib.entrypoint
|
|
776
1146
|
@server_common.check_server_healthy_or_start
|
|
777
1147
|
@annotations.client_api
|
|
778
|
-
def down(cluster_name: str,
|
|
1148
|
+
def down(cluster_name: str,
|
|
1149
|
+
purge: bool = False) -> server_common.RequestId[None]:
|
|
779
1150
|
"""Tears down a cluster.
|
|
780
1151
|
|
|
781
1152
|
Tearing down a cluster will delete all associated resources (all billing
|
|
@@ -809,19 +1180,16 @@ def down(cluster_name: str, purge: bool = False) -> server_common.RequestId:
|
|
|
809
1180
|
cluster_name=cluster_name,
|
|
810
1181
|
purge=purge,
|
|
811
1182
|
)
|
|
812
|
-
response =
|
|
813
|
-
|
|
814
|
-
json=json.loads(body.model_dump_json()),
|
|
815
|
-
timeout=5,
|
|
816
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
817
|
-
)
|
|
1183
|
+
response = server_common.make_authenticated_request(
|
|
1184
|
+
'POST', '/down', json=json.loads(body.model_dump_json()), timeout=5)
|
|
818
1185
|
return server_common.get_request_id(response)
|
|
819
1186
|
|
|
820
1187
|
|
|
821
1188
|
@usage_lib.entrypoint
|
|
822
1189
|
@server_common.check_server_healthy_or_start
|
|
823
1190
|
@annotations.client_api
|
|
824
|
-
def stop(cluster_name: str,
|
|
1191
|
+
def stop(cluster_name: str,
|
|
1192
|
+
purge: bool = False) -> server_common.RequestId[None]:
|
|
825
1193
|
"""Stops a cluster.
|
|
826
1194
|
|
|
827
1195
|
Data on attached disks is not lost when a cluster is stopped. Billing for
|
|
@@ -858,12 +1226,8 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
|
|
|
858
1226
|
cluster_name=cluster_name,
|
|
859
1227
|
purge=purge,
|
|
860
1228
|
)
|
|
861
|
-
response =
|
|
862
|
-
|
|
863
|
-
json=json.loads(body.model_dump_json()),
|
|
864
|
-
timeout=5,
|
|
865
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
866
|
-
)
|
|
1229
|
+
response = server_common.make_authenticated_request(
|
|
1230
|
+
'POST', '/stop', json=json.loads(body.model_dump_json()), timeout=5)
|
|
867
1231
|
return server_common.get_request_id(response)
|
|
868
1232
|
|
|
869
1233
|
|
|
@@ -871,10 +1235,11 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
|
|
|
871
1235
|
@server_common.check_server_healthy_or_start
|
|
872
1236
|
@annotations.client_api
|
|
873
1237
|
def autostop(
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
1238
|
+
cluster_name: str,
|
|
1239
|
+
idle_minutes: int,
|
|
1240
|
+
wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
|
|
1241
|
+
down: bool = False, # pylint: disable=redefined-outer-name
|
|
1242
|
+
) -> server_common.RequestId[None]:
|
|
878
1243
|
"""Schedules an autostop/autodown for a cluster.
|
|
879
1244
|
|
|
880
1245
|
Autostop/autodown will automatically stop or teardown a cluster when it
|
|
@@ -904,6 +1269,14 @@ def autostop(
|
|
|
904
1269
|
idle_minutes: the number of minutes of idleness (no pending/running
|
|
905
1270
|
jobs) after which the cluster will be stopped automatically. Setting
|
|
906
1271
|
to a negative number cancels any autostop/autodown setting.
|
|
1272
|
+
wait_for: determines the condition for resetting the idleness timer.
|
|
1273
|
+
This option works in conjunction with ``idle_minutes``.
|
|
1274
|
+
Choices:
|
|
1275
|
+
|
|
1276
|
+
1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
|
|
1277
|
+
connections to finish.
|
|
1278
|
+
2. "jobs" - Only wait for in-progress jobs.
|
|
1279
|
+
3. "none" - Wait for nothing; autostop right after ``idle_minutes``.
|
|
907
1280
|
down: if true, use autodown (tear down the cluster; non-restartable),
|
|
908
1281
|
rather than autostop (restartable).
|
|
909
1282
|
|
|
@@ -923,26 +1296,31 @@ def autostop(
|
|
|
923
1296
|
sky.exceptions.CloudUserIdentityError: if we fail to get the current
|
|
924
1297
|
user identity.
|
|
925
1298
|
"""
|
|
1299
|
+
remote_api_version = versions.get_remote_api_version()
|
|
1300
|
+
if wait_for is not None and (remote_api_version is None or
|
|
1301
|
+
remote_api_version < 13):
|
|
1302
|
+
logger.warning('wait_for is not supported in your API server. '
|
|
1303
|
+
'Please upgrade to a newer API server to use it.')
|
|
1304
|
+
|
|
926
1305
|
body = payloads.AutostopBody(
|
|
927
1306
|
cluster_name=cluster_name,
|
|
928
1307
|
idle_minutes=idle_minutes,
|
|
1308
|
+
wait_for=wait_for,
|
|
929
1309
|
down=down,
|
|
930
1310
|
)
|
|
931
|
-
response =
|
|
932
|
-
|
|
933
|
-
json=json.loads(body.model_dump_json()),
|
|
934
|
-
timeout=5,
|
|
935
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
936
|
-
)
|
|
1311
|
+
response = server_common.make_authenticated_request(
|
|
1312
|
+
'POST', '/autostop', json=json.loads(body.model_dump_json()), timeout=5)
|
|
937
1313
|
return server_common.get_request_id(response)
|
|
938
1314
|
|
|
939
1315
|
|
|
940
1316
|
@usage_lib.entrypoint
|
|
941
1317
|
@server_common.check_server_healthy_or_start
|
|
942
1318
|
@annotations.client_api
|
|
943
|
-
def queue(
|
|
944
|
-
|
|
945
|
-
|
|
1319
|
+
def queue(
|
|
1320
|
+
cluster_name: str,
|
|
1321
|
+
skip_finished: bool = False,
|
|
1322
|
+
all_users: bool = False
|
|
1323
|
+
) -> server_common.RequestId[List[responses.ClusterJobRecord]]:
|
|
946
1324
|
"""Gets the job queue of a cluster.
|
|
947
1325
|
|
|
948
1326
|
Args:
|
|
@@ -955,8 +1333,8 @@ def queue(cluster_name: str,
|
|
|
955
1333
|
The request ID of the queue request.
|
|
956
1334
|
|
|
957
1335
|
Request Returns:
|
|
958
|
-
job_records (List[
|
|
959
|
-
queue.
|
|
1336
|
+
job_records (List[responses.ClusterJobRecord]): A list of job records
|
|
1337
|
+
for each job in the queue.
|
|
960
1338
|
|
|
961
1339
|
.. code-block:: python
|
|
962
1340
|
|
|
@@ -991,17 +1369,19 @@ def queue(cluster_name: str,
|
|
|
991
1369
|
skip_finished=skip_finished,
|
|
992
1370
|
all_users=all_users,
|
|
993
1371
|
)
|
|
994
|
-
response =
|
|
995
|
-
|
|
996
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1372
|
+
response = server_common.make_authenticated_request(
|
|
1373
|
+
'POST', '/queue', json=json.loads(body.model_dump_json()))
|
|
997
1374
|
return server_common.get_request_id(response)
|
|
998
1375
|
|
|
999
1376
|
|
|
1000
1377
|
@usage_lib.entrypoint
|
|
1001
1378
|
@server_common.check_server_healthy_or_start
|
|
1002
1379
|
@annotations.client_api
|
|
1003
|
-
def job_status(
|
|
1004
|
-
|
|
1380
|
+
def job_status(
|
|
1381
|
+
cluster_name: str,
|
|
1382
|
+
job_ids: Optional[List[int]] = None
|
|
1383
|
+
) -> server_common.RequestId[Dict[Optional[int],
|
|
1384
|
+
Optional['job_lib.JobStatus']]]:
|
|
1005
1385
|
"""Gets the status of jobs on a cluster.
|
|
1006
1386
|
|
|
1007
1387
|
Args:
|
|
@@ -1033,9 +1413,8 @@ def job_status(cluster_name: str,
|
|
|
1033
1413
|
cluster_name=cluster_name,
|
|
1034
1414
|
job_ids=job_ids,
|
|
1035
1415
|
)
|
|
1036
|
-
response =
|
|
1037
|
-
|
|
1038
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1416
|
+
response = server_common.make_authenticated_request(
|
|
1417
|
+
'POST', '/job_status', json=json.loads(body.model_dump_json()))
|
|
1039
1418
|
return server_common.get_request_id(response)
|
|
1040
1419
|
|
|
1041
1420
|
|
|
@@ -1049,7 +1428,7 @@ def cancel(
|
|
|
1049
1428
|
job_ids: Optional[List[int]] = None,
|
|
1050
1429
|
# pylint: disable=invalid-name
|
|
1051
1430
|
_try_cancel_if_cluster_is_init: bool = False
|
|
1052
|
-
) -> server_common.RequestId:
|
|
1431
|
+
) -> server_common.RequestId[None]:
|
|
1053
1432
|
"""Cancels jobs on a cluster.
|
|
1054
1433
|
|
|
1055
1434
|
Args:
|
|
@@ -1087,9 +1466,8 @@ def cancel(
|
|
|
1087
1466
|
job_ids=job_ids,
|
|
1088
1467
|
try_cancel_if_cluster_is_init=_try_cancel_if_cluster_is_init,
|
|
1089
1468
|
)
|
|
1090
|
-
response =
|
|
1091
|
-
|
|
1092
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1469
|
+
response = server_common.make_authenticated_request(
|
|
1470
|
+
'POST', '/cancel', json=json.loads(body.model_dump_json()))
|
|
1093
1471
|
return server_common.get_request_id(response)
|
|
1094
1472
|
|
|
1095
1473
|
|
|
@@ -1100,7 +1478,10 @@ def status(
|
|
|
1100
1478
|
cluster_names: Optional[List[str]] = None,
|
|
1101
1479
|
refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE,
|
|
1102
1480
|
all_users: bool = False,
|
|
1103
|
-
|
|
1481
|
+
*,
|
|
1482
|
+
_include_credentials: bool = False,
|
|
1483
|
+
_summary_response: bool = False,
|
|
1484
|
+
) -> server_common.RequestId[List[responses.StatusResponse]]:
|
|
1104
1485
|
"""Gets cluster statuses.
|
|
1105
1486
|
|
|
1106
1487
|
If cluster_names is given, return those clusters. Otherwise, return all
|
|
@@ -1148,6 +1529,8 @@ def status(
|
|
|
1148
1529
|
provider(s).
|
|
1149
1530
|
all_users: whether to include all users' clusters. By default, only
|
|
1150
1531
|
the current user's clusters are included.
|
|
1532
|
+
_include_credentials: (internal) whether to include cluster ssh
|
|
1533
|
+
credentials in the response (default: False).
|
|
1151
1534
|
|
|
1152
1535
|
Returns:
|
|
1153
1536
|
The request ID of the status request.
|
|
@@ -1182,10 +1565,11 @@ def status(
|
|
|
1182
1565
|
cluster_names=cluster_names,
|
|
1183
1566
|
refresh=refresh,
|
|
1184
1567
|
all_users=all_users,
|
|
1568
|
+
include_credentials=_include_credentials,
|
|
1569
|
+
summary_response=_summary_response,
|
|
1185
1570
|
)
|
|
1186
|
-
response =
|
|
1187
|
-
|
|
1188
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1571
|
+
response = server_common.make_authenticated_request(
|
|
1572
|
+
'POST', '/status', json=json.loads(body.model_dump_json()))
|
|
1189
1573
|
return server_common.get_request_id(response)
|
|
1190
1574
|
|
|
1191
1575
|
|
|
@@ -1193,10 +1577,19 @@ def status(
|
|
|
1193
1577
|
@server_common.check_server_healthy_or_start
|
|
1194
1578
|
@annotations.client_api
|
|
1195
1579
|
def endpoints(
|
|
1196
|
-
|
|
1197
|
-
|
|
1580
|
+
cluster: str,
|
|
1581
|
+
port: Optional[Union[int, str]] = None
|
|
1582
|
+
) -> server_common.RequestId[Dict[int, str]]:
|
|
1198
1583
|
"""Gets the endpoint for a given cluster and port number (endpoint).
|
|
1199
1584
|
|
|
1585
|
+
Example:
|
|
1586
|
+
.. code-block:: python
|
|
1587
|
+
|
|
1588
|
+
import sky
|
|
1589
|
+
request_id = sky.endpoints('test-cluster')
|
|
1590
|
+
sky.get(request_id)
|
|
1591
|
+
|
|
1592
|
+
|
|
1200
1593
|
Args:
|
|
1201
1594
|
cluster: The name of the cluster.
|
|
1202
1595
|
port: The port number to get the endpoint for. If None, endpoints
|
|
@@ -1206,8 +1599,9 @@ def endpoints(
|
|
|
1206
1599
|
The request ID of the endpoints request.
|
|
1207
1600
|
|
|
1208
1601
|
Request Returns:
|
|
1209
|
-
A dictionary of port numbers to endpoints.
|
|
1210
|
-
|
|
1602
|
+
A dictionary of port numbers to endpoints.
|
|
1603
|
+
If port is None, the dictionary contains all
|
|
1604
|
+
ports:endpoints exposed on the cluster.
|
|
1211
1605
|
|
|
1212
1606
|
Request Raises:
|
|
1213
1607
|
ValueError: if the cluster is not UP or the endpoint is not exposed.
|
|
@@ -1218,16 +1612,17 @@ def endpoints(
|
|
|
1218
1612
|
cluster=cluster,
|
|
1219
1613
|
port=port,
|
|
1220
1614
|
)
|
|
1221
|
-
response =
|
|
1222
|
-
|
|
1223
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1615
|
+
response = server_common.make_authenticated_request(
|
|
1616
|
+
'POST', '/endpoints', json=json.loads(body.model_dump_json()))
|
|
1224
1617
|
return server_common.get_request_id(response)
|
|
1225
1618
|
|
|
1226
1619
|
|
|
1227
1620
|
@usage_lib.entrypoint
|
|
1228
1621
|
@server_common.check_server_healthy_or_start
|
|
1229
1622
|
@annotations.client_api
|
|
1230
|
-
def cost_report(
|
|
1623
|
+
def cost_report(
|
|
1624
|
+
days: Optional[int] = None
|
|
1625
|
+
) -> server_common.RequestId[List[Dict[str, Any]]]: # pylint: disable=redefined-builtin
|
|
1231
1626
|
"""Gets all cluster cost reports, including those that have been downed.
|
|
1232
1627
|
|
|
1233
1628
|
The estimated cost column indicates price for the cluster based on the type
|
|
@@ -1237,6 +1632,10 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
|
|
|
1237
1632
|
cache of the cluster status, and may not be accurate for the cluster with
|
|
1238
1633
|
autostop/use_spot set or terminated/stopped on the cloud console.
|
|
1239
1634
|
|
|
1635
|
+
Args:
|
|
1636
|
+
days: The number of days to get the cost report for. If not provided,
|
|
1637
|
+
the default is 30 days.
|
|
1638
|
+
|
|
1240
1639
|
Returns:
|
|
1241
1640
|
The request ID of the cost report request.
|
|
1242
1641
|
|
|
@@ -1258,8 +1657,9 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
|
|
|
1258
1657
|
'total_cost': (float) cost given resources and usage intervals,
|
|
1259
1658
|
}
|
|
1260
1659
|
"""
|
|
1261
|
-
|
|
1262
|
-
|
|
1660
|
+
body = payloads.CostReportBody(days=days)
|
|
1661
|
+
response = server_common.make_authenticated_request(
|
|
1662
|
+
'POST', '/cost_report', json=json.loads(body.model_dump_json()))
|
|
1263
1663
|
return server_common.get_request_id(response)
|
|
1264
1664
|
|
|
1265
1665
|
|
|
@@ -1267,36 +1667,24 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
|
|
|
1267
1667
|
@usage_lib.entrypoint
|
|
1268
1668
|
@server_common.check_server_healthy_or_start
|
|
1269
1669
|
@annotations.client_api
|
|
1270
|
-
def storage_ls() -> server_common.RequestId:
|
|
1670
|
+
def storage_ls() -> server_common.RequestId[List[responses.StorageRecord]]:
|
|
1271
1671
|
"""Gets the storages.
|
|
1272
1672
|
|
|
1273
1673
|
Returns:
|
|
1274
1674
|
The request ID of the storage list request.
|
|
1275
1675
|
|
|
1276
1676
|
Request Returns:
|
|
1277
|
-
storage_records (List[
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
.. code-block:: python
|
|
1281
|
-
|
|
1282
|
-
{
|
|
1283
|
-
'name': (str) storage name,
|
|
1284
|
-
'launched_at': (int) timestamp of creation,
|
|
1285
|
-
'store': (List[sky.StoreType]) storage type,
|
|
1286
|
-
'last_use': (int) timestamp of last use,
|
|
1287
|
-
'status': (sky.StorageStatus) storage status,
|
|
1288
|
-
}
|
|
1289
|
-
]
|
|
1677
|
+
storage_records (List[responses.StorageRecord]):
|
|
1678
|
+
A list of storage records.
|
|
1290
1679
|
"""
|
|
1291
|
-
response =
|
|
1292
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1680
|
+
response = server_common.make_authenticated_request('GET', '/storage/ls')
|
|
1293
1681
|
return server_common.get_request_id(response)
|
|
1294
1682
|
|
|
1295
1683
|
|
|
1296
1684
|
@usage_lib.entrypoint
|
|
1297
1685
|
@server_common.check_server_healthy_or_start
|
|
1298
1686
|
@annotations.client_api
|
|
1299
|
-
def storage_delete(name: str) -> server_common.RequestId:
|
|
1687
|
+
def storage_delete(name: str) -> server_common.RequestId[None]:
|
|
1300
1688
|
"""Deletes a storage.
|
|
1301
1689
|
|
|
1302
1690
|
Args:
|
|
@@ -1312,9 +1700,8 @@ def storage_delete(name: str) -> server_common.RequestId:
|
|
|
1312
1700
|
ValueError: If the storage does not exist.
|
|
1313
1701
|
"""
|
|
1314
1702
|
body = payloads.StorageBody(name=name)
|
|
1315
|
-
response =
|
|
1316
|
-
|
|
1317
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1703
|
+
response = server_common.make_authenticated_request(
|
|
1704
|
+
'POST', '/storage/delete', json=json.loads(body.model_dump_json()))
|
|
1318
1705
|
return server_common.get_request_id(response)
|
|
1319
1706
|
|
|
1320
1707
|
|
|
@@ -1325,12 +1712,8 @@ def storage_delete(name: str) -> server_common.RequestId:
|
|
|
1325
1712
|
@server_common.check_server_healthy_or_start
|
|
1326
1713
|
@annotations.client_api
|
|
1327
1714
|
def local_up(gpus: bool,
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
ssh_key: Optional[str],
|
|
1331
|
-
cleanup: bool,
|
|
1332
|
-
context_name: Optional[str] = None,
|
|
1333
|
-
password: Optional[str] = None) -> server_common.RequestId:
|
|
1715
|
+
name: Optional[str] = None,
|
|
1716
|
+
port_start: Optional[int] = None) -> server_common.RequestId[None]:
|
|
1334
1717
|
"""Launches a Kubernetes cluster on local machines.
|
|
1335
1718
|
|
|
1336
1719
|
Returns:
|
|
@@ -1341,36 +1724,151 @@ def local_up(gpus: bool,
|
|
|
1341
1724
|
# TODO: move this check to server.
|
|
1342
1725
|
if not server_common.is_api_server_local():
|
|
1343
1726
|
with ux_utils.print_exception_no_traceback():
|
|
1344
|
-
raise ValueError(
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
body = payloads.LocalUpBody(gpus=gpus,
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
ssh_key=ssh_key,
|
|
1351
|
-
cleanup=cleanup,
|
|
1352
|
-
context_name=context_name,
|
|
1353
|
-
password=password)
|
|
1354
|
-
response = requests.post(f'{server_common.get_server_url()}/local_up',
|
|
1355
|
-
json=json.loads(body.model_dump_json()),
|
|
1356
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1727
|
+
raise ValueError('`sky local up` is only supported when '
|
|
1728
|
+
'running SkyPilot locally.')
|
|
1729
|
+
|
|
1730
|
+
body = payloads.LocalUpBody(gpus=gpus, name=name, port_start=port_start)
|
|
1731
|
+
response = server_common.make_authenticated_request(
|
|
1732
|
+
'POST', '/local_up', json=json.loads(body.model_dump_json()))
|
|
1357
1733
|
return server_common.get_request_id(response)
|
|
1358
1734
|
|
|
1359
1735
|
|
|
1360
1736
|
@usage_lib.entrypoint
|
|
1361
1737
|
@server_common.check_server_healthy_or_start
|
|
1362
1738
|
@annotations.client_api
|
|
1363
|
-
def local_down() -> server_common.RequestId:
|
|
1739
|
+
def local_down(name: Optional[str]) -> server_common.RequestId[None]:
|
|
1364
1740
|
"""Tears down the Kubernetes cluster started by local_up."""
|
|
1365
1741
|
# We do not allow local up when the API server is running remotely since it
|
|
1366
1742
|
# will modify the kubeconfig.
|
|
1367
1743
|
# TODO: move this check to remote server.
|
|
1368
1744
|
if not server_common.is_api_server_local():
|
|
1369
1745
|
with ux_utils.print_exception_no_traceback():
|
|
1370
|
-
raise ValueError('sky local down is only supported when running '
|
|
1746
|
+
raise ValueError('`sky local down` is only supported when running '
|
|
1371
1747
|
'SkyPilot locally.')
|
|
1372
|
-
|
|
1373
|
-
|
|
1748
|
+
|
|
1749
|
+
body = payloads.LocalDownBody(name=name)
|
|
1750
|
+
response = server_common.make_authenticated_request(
|
|
1751
|
+
'POST', '/local_down', json=json.loads(body.model_dump_json()))
|
|
1752
|
+
return server_common.get_request_id(response)
|
|
1753
|
+
|
|
1754
|
+
|
|
1755
|
+
def _update_remote_ssh_node_pools(file: str,
|
|
1756
|
+
infra: Optional[str] = None) -> None:
|
|
1757
|
+
"""Update the SSH node pools on the remote server.
|
|
1758
|
+
|
|
1759
|
+
This function will also upload the local SSH key to the remote server, and
|
|
1760
|
+
replace the file path to the remote SSH key file path.
|
|
1761
|
+
|
|
1762
|
+
Args:
|
|
1763
|
+
file: The path to the local SSH node pools config file.
|
|
1764
|
+
infra: The name of the cluster configuration in the local SSH node
|
|
1765
|
+
pools config file. If None, all clusters in the file are updated.
|
|
1766
|
+
"""
|
|
1767
|
+
file = os.path.expanduser(file)
|
|
1768
|
+
if not os.path.exists(file):
|
|
1769
|
+
with ux_utils.print_exception_no_traceback():
|
|
1770
|
+
raise ValueError(
|
|
1771
|
+
f'SSH Node Pool config file {file} does not exist. '
|
|
1772
|
+
'Please check if the file exists and the path is correct.')
|
|
1773
|
+
config = ssh_utils.load_ssh_targets(file)
|
|
1774
|
+
config = ssh_utils.get_cluster_config(config, infra)
|
|
1775
|
+
pools_config = {}
|
|
1776
|
+
for name, pool_config in config.items():
|
|
1777
|
+
hosts_info = ssh_utils.prepare_hosts_info(
|
|
1778
|
+
name, pool_config, upload_ssh_key_func=_upload_ssh_key_and_wait)
|
|
1779
|
+
pools_config[name] = {'hosts': hosts_info}
|
|
1780
|
+
server_common.make_authenticated_request('POST',
|
|
1781
|
+
'/ssh_node_pools',
|
|
1782
|
+
json=pools_config)
|
|
1783
|
+
|
|
1784
|
+
|
|
1785
|
+
def _upload_ssh_key_and_wait(key_name: str, key_file_path: str) -> str:
|
|
1786
|
+
"""Upload the SSH key to the remote server and wait for the key to be
|
|
1787
|
+
uploaded.
|
|
1788
|
+
|
|
1789
|
+
Args:
|
|
1790
|
+
key_name: The name of the SSH key.
|
|
1791
|
+
key_file_path: The path to the local SSH key file.
|
|
1792
|
+
|
|
1793
|
+
Returns:
|
|
1794
|
+
The path for the remote SSH key file on the API server.
|
|
1795
|
+
"""
|
|
1796
|
+
if not os.path.exists(os.path.expanduser(key_file_path)):
|
|
1797
|
+
with ux_utils.print_exception_no_traceback():
|
|
1798
|
+
raise ValueError(f'SSH key file not found: {key_file_path}')
|
|
1799
|
+
|
|
1800
|
+
with open(os.path.expanduser(key_file_path), 'rb') as key_file:
|
|
1801
|
+
response = server_common.make_authenticated_request(
|
|
1802
|
+
'POST',
|
|
1803
|
+
'/ssh_node_pools/keys',
|
|
1804
|
+
files={
|
|
1805
|
+
'key_file': (key_name, key_file, 'application/octet-stream')
|
|
1806
|
+
},
|
|
1807
|
+
data={'key_name': key_name},
|
|
1808
|
+
cookies=server_common.get_api_cookie_jar())
|
|
1809
|
+
|
|
1810
|
+
return response.json()['key_path']
|
|
1811
|
+
|
|
1812
|
+
|
|
1813
|
+
@usage_lib.entrypoint
|
|
1814
|
+
@server_common.check_server_healthy_or_start
|
|
1815
|
+
@annotations.client_api
|
|
1816
|
+
def ssh_up(infra: Optional[str] = None,
|
|
1817
|
+
file: Optional[str] = None) -> server_common.RequestId[None]:
|
|
1818
|
+
"""Deploys the SSH Node Pools defined in ~/.sky/ssh_targets.yaml.
|
|
1819
|
+
|
|
1820
|
+
Args:
|
|
1821
|
+
infra: Name of the cluster configuration in ssh_targets.yaml.
|
|
1822
|
+
If None, the first cluster in the file is used.
|
|
1823
|
+
file: Name of the ssh node pool configuration file to use. If
|
|
1824
|
+
None, the default path, ~/.sky/ssh_node_pools.yaml is used.
|
|
1825
|
+
|
|
1826
|
+
Returns:
|
|
1827
|
+
request_id: The request ID of the SSH cluster deployment request.
|
|
1828
|
+
"""
|
|
1829
|
+
if file is not None:
|
|
1830
|
+
_update_remote_ssh_node_pools(file, infra)
|
|
1831
|
+
|
|
1832
|
+
# Use SSH node pools router endpoint
|
|
1833
|
+
body = payloads.SSHUpBody(infra=infra, cleanup=False)
|
|
1834
|
+
if infra is not None:
|
|
1835
|
+
# Call the specific pool deployment endpoint
|
|
1836
|
+
response = server_common.make_authenticated_request(
|
|
1837
|
+
'POST', f'/ssh_node_pools/{infra}/deploy')
|
|
1838
|
+
else:
|
|
1839
|
+
# Call the general deployment endpoint
|
|
1840
|
+
response = server_common.make_authenticated_request(
|
|
1841
|
+
'POST',
|
|
1842
|
+
'/ssh_node_pools/deploy',
|
|
1843
|
+
json=json.loads(body.model_dump_json()))
|
|
1844
|
+
return server_common.get_request_id(response)
|
|
1845
|
+
|
|
1846
|
+
|
|
1847
|
+
@usage_lib.entrypoint
|
|
1848
|
+
@server_common.check_server_healthy_or_start
|
|
1849
|
+
@annotations.client_api
|
|
1850
|
+
def ssh_down(infra: Optional[str] = None) -> server_common.RequestId[None]:
|
|
1851
|
+
"""Tears down a Kubernetes cluster on SSH targets.
|
|
1852
|
+
|
|
1853
|
+
Args:
|
|
1854
|
+
infra: Name of the cluster configuration in ssh_targets.yaml.
|
|
1855
|
+
If None, the first cluster in the file is used.
|
|
1856
|
+
|
|
1857
|
+
Returns:
|
|
1858
|
+
request_id: The request ID of the SSH cluster teardown request.
|
|
1859
|
+
"""
|
|
1860
|
+
# Use SSH node pools router endpoint
|
|
1861
|
+
body = payloads.SSHUpBody(infra=infra, cleanup=True)
|
|
1862
|
+
if infra is not None:
|
|
1863
|
+
# Call the specific pool down endpoint
|
|
1864
|
+
response = server_common.make_authenticated_request(
|
|
1865
|
+
'POST', f'/ssh_node_pools/{infra}/down')
|
|
1866
|
+
else:
|
|
1867
|
+
# Call the general down endpoint
|
|
1868
|
+
response = server_common.make_authenticated_request(
|
|
1869
|
+
'POST',
|
|
1870
|
+
'/ssh_node_pools/down',
|
|
1871
|
+
json=json.loads(body.model_dump_json()))
|
|
1374
1872
|
return server_common.get_request_id(response)
|
|
1375
1873
|
|
|
1376
1874
|
|
|
@@ -1378,9 +1876,12 @@ def local_down() -> server_common.RequestId:
|
|
|
1378
1876
|
@server_common.check_server_healthy_or_start
|
|
1379
1877
|
@annotations.client_api
|
|
1380
1878
|
def realtime_kubernetes_gpu_availability(
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1879
|
+
context: Optional[str] = None,
|
|
1880
|
+
name_filter: Optional[str] = None,
|
|
1881
|
+
quantity_filter: Optional[int] = None,
|
|
1882
|
+
is_ssh: Optional[bool] = None
|
|
1883
|
+
) -> server_common.RequestId[List[Tuple[
|
|
1884
|
+
str, List['models.RealtimeGpuAvailability']]]]:
|
|
1384
1885
|
"""Gets the real-time Kubernetes GPU availability.
|
|
1385
1886
|
|
|
1386
1887
|
Returns:
|
|
@@ -1390,12 +1891,12 @@ def realtime_kubernetes_gpu_availability(
|
|
|
1390
1891
|
context=context,
|
|
1391
1892
|
name_filter=name_filter,
|
|
1392
1893
|
quantity_filter=quantity_filter,
|
|
1894
|
+
is_ssh=is_ssh,
|
|
1393
1895
|
)
|
|
1394
|
-
response =
|
|
1395
|
-
|
|
1396
|
-
'realtime_kubernetes_gpu_availability',
|
|
1397
|
-
json=json.loads(body.model_dump_json())
|
|
1398
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1896
|
+
response = server_common.make_authenticated_request(
|
|
1897
|
+
'POST',
|
|
1898
|
+
'/realtime_kubernetes_gpu_availability',
|
|
1899
|
+
json=json.loads(body.model_dump_json()))
|
|
1399
1900
|
return server_common.get_request_id(response)
|
|
1400
1901
|
|
|
1401
1902
|
|
|
@@ -1403,7 +1904,8 @@ def realtime_kubernetes_gpu_availability(
|
|
|
1403
1904
|
@server_common.check_server_healthy_or_start
|
|
1404
1905
|
@annotations.client_api
|
|
1405
1906
|
def kubernetes_node_info(
|
|
1406
|
-
|
|
1907
|
+
context: Optional[str] = None
|
|
1908
|
+
) -> server_common.RequestId['models.KubernetesNodesInfo']:
|
|
1407
1909
|
"""Gets the resource information for all the nodes in the cluster.
|
|
1408
1910
|
|
|
1409
1911
|
Currently only GPU resources are supported. The function returns the total
|
|
@@ -1424,18 +1926,22 @@ def kubernetes_node_info(
|
|
|
1424
1926
|
information.
|
|
1425
1927
|
"""
|
|
1426
1928
|
body = payloads.KubernetesNodeInfoRequestBody(context=context)
|
|
1427
|
-
response =
|
|
1428
|
-
|
|
1429
|
-
|
|
1430
|
-
|
|
1929
|
+
response = server_common.make_authenticated_request(
|
|
1930
|
+
'POST',
|
|
1931
|
+
'/kubernetes_node_info',
|
|
1932
|
+
json=json.loads(body.model_dump_json()))
|
|
1431
1933
|
return server_common.get_request_id(response)
|
|
1432
1934
|
|
|
1433
1935
|
|
|
1434
1936
|
@usage_lib.entrypoint
|
|
1435
1937
|
@server_common.check_server_healthy_or_start
|
|
1436
1938
|
@annotations.client_api
|
|
1437
|
-
def status_kubernetes() -> server_common.RequestId
|
|
1438
|
-
|
|
1939
|
+
def status_kubernetes() -> server_common.RequestId[
|
|
1940
|
+
Tuple[List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
|
|
1941
|
+
List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
|
|
1942
|
+
List[responses.ManagedJobRecord], Optional[str]]]:
|
|
1943
|
+
"""[Experimental] Gets all SkyPilot clusters and jobs
|
|
1944
|
+
in the Kubernetes cluster.
|
|
1439
1945
|
|
|
1440
1946
|
Managed jobs and services are also included in the clusters returned.
|
|
1441
1947
|
The caller must parse the controllers to identify which clusters are run
|
|
@@ -1455,21 +1961,24 @@ def status_kubernetes() -> server_common.RequestId:
|
|
|
1455
1961
|
dictionary job info, see jobs.queue_from_kubernetes_pod for details.
|
|
1456
1962
|
- context: Kubernetes context used to fetch the cluster information.
|
|
1457
1963
|
"""
|
|
1458
|
-
response =
|
|
1459
|
-
|
|
1460
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1964
|
+
response = server_common.make_authenticated_request('GET',
|
|
1965
|
+
'/status_kubernetes')
|
|
1461
1966
|
return server_common.get_request_id(response)
|
|
1462
1967
|
|
|
1463
1968
|
|
|
1464
1969
|
# === API request APIs ===
|
|
1465
1970
|
@usage_lib.entrypoint
|
|
1466
|
-
@server_common.check_server_healthy_or_start
|
|
1467
1971
|
@annotations.client_api
|
|
1468
|
-
def get(request_id:
|
|
1972
|
+
def get(request_id: server_common.RequestId[T]) -> T:
|
|
1469
1973
|
"""Waits for and gets the result of a request.
|
|
1470
1974
|
|
|
1975
|
+
This function will not check the server health since /api/get is typically
|
|
1976
|
+
not the first API call in an SDK session and checking the server health
|
|
1977
|
+
may cause GET /api/get being sent to a restarted API server.
|
|
1978
|
+
|
|
1471
1979
|
Args:
|
|
1472
|
-
request_id: The request ID of the request to get.
|
|
1980
|
+
request_id: The request ID of the request to get. May be a full request
|
|
1981
|
+
ID or a prefix.
|
|
1473
1982
|
|
|
1474
1983
|
Returns:
|
|
1475
1984
|
The ``Request Returns`` of the specified request. See the documentation
|
|
@@ -1480,19 +1989,20 @@ def get(request_id: str) -> Any:
|
|
|
1480
1989
|
see ``Request Raises`` in the documentation of the specific requests
|
|
1481
1990
|
above.
|
|
1482
1991
|
"""
|
|
1483
|
-
response =
|
|
1484
|
-
|
|
1992
|
+
response = server_common.make_authenticated_request(
|
|
1993
|
+
'GET',
|
|
1994
|
+
f'/api/get?request_id={request_id}',
|
|
1995
|
+
retry=False,
|
|
1485
1996
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
1486
|
-
None)
|
|
1487
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1997
|
+
None))
|
|
1488
1998
|
request_task = None
|
|
1489
1999
|
if response.status_code == 200:
|
|
1490
2000
|
request_task = requests_lib.Request.decode(
|
|
1491
|
-
|
|
2001
|
+
payloads.RequestPayload(**response.json()))
|
|
1492
2002
|
elif response.status_code == 500:
|
|
1493
2003
|
try:
|
|
1494
2004
|
request_task = requests_lib.Request.decode(
|
|
1495
|
-
|
|
2005
|
+
payloads.RequestPayload(**response.json().get('detail')))
|
|
1496
2006
|
logger.debug(f'Got request with error: {request_task.name}')
|
|
1497
2007
|
except Exception: # pylint: disable=broad-except
|
|
1498
2008
|
request_task = None
|
|
@@ -1503,12 +2013,7 @@ def get(request_id: str) -> Any:
|
|
|
1503
2013
|
error = request_task.get_error()
|
|
1504
2014
|
if error is not None:
|
|
1505
2015
|
error_obj = error['object']
|
|
1506
|
-
|
|
1507
|
-
stacktrace = getattr(error_obj, 'stacktrace', str(error_obj))
|
|
1508
|
-
logger.error('=== Traceback on SkyPilot API Server ===\n'
|
|
1509
|
-
f'{stacktrace}')
|
|
1510
|
-
with ux_utils.print_exception_no_traceback():
|
|
1511
|
-
raise error_obj
|
|
2016
|
+
_raise_exception_object_on_client(error_obj)
|
|
1512
2017
|
if request_task.status == requests_lib.RequestStatus.CANCELLED:
|
|
1513
2018
|
with ux_utils.print_exception_no_traceback():
|
|
1514
2019
|
raise exceptions.RequestCancelled(
|
|
@@ -1518,23 +2023,45 @@ def get(request_id: str) -> Any:
|
|
|
1518
2023
|
return request_task.get_return_value()
|
|
1519
2024
|
|
|
1520
2025
|
|
|
2026
|
+
@typing.overload
|
|
2027
|
+
def stream_and_get(request_id: server_common.RequestId[T],
|
|
2028
|
+
log_path: Optional[str] = None,
|
|
2029
|
+
tail: Optional[int] = None,
|
|
2030
|
+
follow: bool = True,
|
|
2031
|
+
output_stream: Optional['io.TextIOBase'] = None) -> T:
|
|
2032
|
+
...
|
|
2033
|
+
|
|
2034
|
+
|
|
2035
|
+
@typing.overload
|
|
2036
|
+
def stream_and_get(request_id: None = None,
|
|
2037
|
+
log_path: Optional[str] = None,
|
|
2038
|
+
tail: Optional[int] = None,
|
|
2039
|
+
follow: bool = True,
|
|
2040
|
+
output_stream: Optional['io.TextIOBase'] = None) -> None:
|
|
2041
|
+
...
|
|
2042
|
+
|
|
2043
|
+
|
|
1521
2044
|
@usage_lib.entrypoint
|
|
1522
2045
|
@server_common.check_server_healthy_or_start
|
|
1523
2046
|
@annotations.client_api
|
|
2047
|
+
@rest.retry_transient_errors()
|
|
1524
2048
|
def stream_and_get(
|
|
1525
|
-
request_id: Optional[
|
|
2049
|
+
request_id: Optional[server_common.RequestId[T]] = None,
|
|
1526
2050
|
log_path: Optional[str] = None,
|
|
1527
2051
|
tail: Optional[int] = None,
|
|
1528
2052
|
follow: bool = True,
|
|
1529
2053
|
output_stream: Optional['io.TextIOBase'] = None,
|
|
1530
|
-
) ->
|
|
2054
|
+
) -> Optional[T]:
|
|
1531
2055
|
"""Streams the logs of a request or a log file and gets the final result.
|
|
1532
2056
|
|
|
1533
2057
|
This will block until the request is finished. The request id can be a
|
|
1534
2058
|
prefix of the full request id.
|
|
1535
2059
|
|
|
1536
2060
|
Args:
|
|
1537
|
-
request_id: The
|
|
2061
|
+
request_id: The request ID of the request to stream. May be a full
|
|
2062
|
+
request ID or a prefix.
|
|
2063
|
+
If None, the latest request submitted to the API server is streamed.
|
|
2064
|
+
Using None request_id is not recommended in multi-user environments.
|
|
1538
2065
|
log_path: The path to the log file to stream.
|
|
1539
2066
|
tail: The number of lines to show from the end of the logs.
|
|
1540
2067
|
If None, show all logs.
|
|
@@ -1545,6 +2072,8 @@ def stream_and_get(
|
|
|
1545
2072
|
Returns:
|
|
1546
2073
|
The ``Request Returns`` of the specified request. See the documentation
|
|
1547
2074
|
of the specific requests above for more details.
|
|
2075
|
+
If follow is False, will always return None. See note on
|
|
2076
|
+
stream_response.
|
|
1548
2077
|
|
|
1549
2078
|
Raises:
|
|
1550
2079
|
Exception: It raises the same exceptions as the specific requests,
|
|
@@ -1558,27 +2087,44 @@ def stream_and_get(
|
|
|
1558
2087
|
'follow': follow,
|
|
1559
2088
|
'format': 'console',
|
|
1560
2089
|
}
|
|
1561
|
-
response =
|
|
1562
|
-
|
|
2090
|
+
response = server_common.make_authenticated_request(
|
|
2091
|
+
'GET',
|
|
2092
|
+
'/api/stream',
|
|
1563
2093
|
params=params,
|
|
2094
|
+
retry=False,
|
|
1564
2095
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
1565
2096
|
None),
|
|
1566
|
-
stream=True
|
|
1567
|
-
cookies=server_common.get_api_cookie_jar())
|
|
2097
|
+
stream=True)
|
|
1568
2098
|
if response.status_code in [404, 400]:
|
|
1569
2099
|
detail = response.json().get('detail')
|
|
1570
2100
|
with ux_utils.print_exception_no_traceback():
|
|
1571
|
-
raise
|
|
2101
|
+
raise exceptions.ClientError(f'Failed to stream logs: {detail}')
|
|
2102
|
+
stream_request_id: Optional[server_common.RequestId[
|
|
2103
|
+
T]] = server_common.get_stream_request_id(response)
|
|
2104
|
+
if request_id is not None and stream_request_id is not None:
|
|
2105
|
+
assert request_id == stream_request_id
|
|
2106
|
+
if request_id is None:
|
|
2107
|
+
request_id = stream_request_id
|
|
1572
2108
|
elif response.status_code != 200:
|
|
2109
|
+
# TODO(syang): handle the case where the requestID is not provided
|
|
2110
|
+
# see https://github.com/skypilot-org/skypilot/issues/6549
|
|
2111
|
+
if request_id is None:
|
|
2112
|
+
return None
|
|
1573
2113
|
return get(request_id)
|
|
1574
|
-
return stream_response(request_id,
|
|
2114
|
+
return stream_response(request_id,
|
|
2115
|
+
response,
|
|
2116
|
+
output_stream,
|
|
2117
|
+
resumable=True,
|
|
2118
|
+
get_result=follow)
|
|
1575
2119
|
|
|
1576
2120
|
|
|
1577
2121
|
@usage_lib.entrypoint
|
|
1578
2122
|
@annotations.client_api
|
|
1579
|
-
def api_cancel(request_ids: Optional[Union[
|
|
2123
|
+
def api_cancel(request_ids: Optional[Union[server_common.RequestId[T],
|
|
2124
|
+
List[server_common.RequestId[T]],
|
|
2125
|
+
str, List[str]]] = None,
|
|
1580
2126
|
all_users: bool = False,
|
|
1581
|
-
silent: bool = False) -> server_common.RequestId:
|
|
2127
|
+
silent: bool = False) -> server_common.RequestId[List[str]]:
|
|
1582
2128
|
"""Aborts a request or all requests.
|
|
1583
2129
|
|
|
1584
2130
|
Args:
|
|
@@ -1618,20 +2164,35 @@ def api_cancel(request_ids: Optional[Union[str, List[str]]] = None,
|
|
|
1618
2164
|
echo(f'Cancelling {len(request_ids)} request{plural}: '
|
|
1619
2165
|
f'{request_id_str}...')
|
|
1620
2166
|
|
|
1621
|
-
response =
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
|
|
2167
|
+
response = server_common.make_authenticated_request(
|
|
2168
|
+
'POST',
|
|
2169
|
+
'/api/cancel',
|
|
2170
|
+
json=json.loads(body.model_dump_json()),
|
|
2171
|
+
timeout=5)
|
|
1625
2172
|
return server_common.get_request_id(response)
|
|
1626
2173
|
|
|
1627
2174
|
|
|
2175
|
+
def _local_api_server_running(kill: bool = False) -> bool:
|
|
2176
|
+
"""Checks if the local api server is running."""
|
|
2177
|
+
for process in psutil.process_iter(attrs=['pid', 'cmdline']):
|
|
2178
|
+
cmdline = process.info['cmdline']
|
|
2179
|
+
if cmdline and server_common.API_SERVER_CMD in ' '.join(cmdline):
|
|
2180
|
+
if kill:
|
|
2181
|
+
subprocess_utils.kill_children_processes(
|
|
2182
|
+
parent_pids=[process.pid], force=True)
|
|
2183
|
+
return True
|
|
2184
|
+
return False
|
|
2185
|
+
|
|
2186
|
+
|
|
1628
2187
|
@usage_lib.entrypoint
|
|
1629
2188
|
@annotations.client_api
|
|
1630
2189
|
def api_status(
|
|
1631
|
-
request_ids: Optional[List[str]] = None,
|
|
2190
|
+
request_ids: Optional[List[Union[server_common.RequestId[T], str]]] = None,
|
|
1632
2191
|
# pylint: disable=redefined-builtin
|
|
1633
|
-
all_status: bool = False
|
|
1634
|
-
|
|
2192
|
+
all_status: bool = False,
|
|
2193
|
+
limit: Optional[int] = None,
|
|
2194
|
+
fields: Optional[List[str]] = None,
|
|
2195
|
+
) -> List[payloads.RequestPayload]:
|
|
1635
2196
|
"""Lists all requests.
|
|
1636
2197
|
|
|
1637
2198
|
Args:
|
|
@@ -1639,29 +2200,37 @@ def api_status(
|
|
|
1639
2200
|
If None, all requests are queried.
|
|
1640
2201
|
all_status: Whether to list all finished requests as well. This argument
|
|
1641
2202
|
is ignored if request_ids is not None.
|
|
2203
|
+
limit: The number of requests to show. If None, show all requests.
|
|
2204
|
+
fields: The fields to get. If None, get all fields.
|
|
1642
2205
|
|
|
1643
2206
|
Returns:
|
|
1644
2207
|
A list of request payloads.
|
|
1645
2208
|
"""
|
|
1646
|
-
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
2209
|
+
if server_common.is_api_server_local() and not _local_api_server_running():
|
|
2210
|
+
logger.info('SkyPilot API server is not running.')
|
|
2211
|
+
return []
|
|
2212
|
+
|
|
2213
|
+
body = payloads.RequestStatusBody(
|
|
2214
|
+
request_ids=request_ids,
|
|
2215
|
+
all_status=all_status,
|
|
2216
|
+
limit=limit,
|
|
2217
|
+
fields=fields,
|
|
2218
|
+
)
|
|
2219
|
+
response = server_common.make_authenticated_request(
|
|
2220
|
+
'GET',
|
|
2221
|
+
'/api/status',
|
|
1650
2222
|
params=server_common.request_body_to_params(body),
|
|
1651
2223
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
1652
|
-
None)
|
|
1653
|
-
cookies=server_common.get_api_cookie_jar())
|
|
2224
|
+
None))
|
|
1654
2225
|
server_common.handle_request_error(response)
|
|
1655
|
-
return [
|
|
1656
|
-
requests_lib.RequestPayload(**request) for request in response.json()
|
|
1657
|
-
]
|
|
2226
|
+
return [payloads.RequestPayload(**request) for request in response.json()]
|
|
1658
2227
|
|
|
1659
2228
|
|
|
1660
2229
|
# === API server management APIs ===
|
|
1661
2230
|
@usage_lib.entrypoint
|
|
1662
2231
|
@server_common.check_server_healthy_or_start
|
|
1663
2232
|
@annotations.client_api
|
|
1664
|
-
def api_info() ->
|
|
2233
|
+
def api_info() -> responses.APIHealthResponse:
|
|
1665
2234
|
"""Gets the server's status, commit and version.
|
|
1666
2235
|
|
|
1667
2236
|
Returns:
|
|
@@ -1674,13 +2243,19 @@ def api_info() -> Dict[str, str]:
|
|
|
1674
2243
|
'api_version': '1',
|
|
1675
2244
|
'commit': 'abc1234567890',
|
|
1676
2245
|
'version': '1.0.0',
|
|
2246
|
+
'version_on_disk': '1.0.0',
|
|
2247
|
+
'user': {
|
|
2248
|
+
'name': 'test@example.com',
|
|
2249
|
+
'id': '12345abcd',
|
|
2250
|
+
},
|
|
1677
2251
|
}
|
|
1678
2252
|
|
|
2253
|
+
Note that user may be None if we are not using an auth proxy.
|
|
2254
|
+
|
|
1679
2255
|
"""
|
|
1680
|
-
response =
|
|
1681
|
-
cookies=server_common.get_api_cookie_jar())
|
|
2256
|
+
response = server_common.make_authenticated_request('GET', '/api/health')
|
|
1682
2257
|
response.raise_for_status()
|
|
1683
|
-
return response.json()
|
|
2258
|
+
return responses.APIHealthResponse(**response.json())
|
|
1684
2259
|
|
|
1685
2260
|
|
|
1686
2261
|
@usage_lib.entrypoint
|
|
@@ -1690,6 +2265,9 @@ def api_start(
|
|
|
1690
2265
|
deploy: bool = False,
|
|
1691
2266
|
host: str = '127.0.0.1',
|
|
1692
2267
|
foreground: bool = False,
|
|
2268
|
+
metrics: bool = False,
|
|
2269
|
+
metrics_port: Optional[int] = None,
|
|
2270
|
+
enable_basic_auth: bool = False,
|
|
1693
2271
|
) -> None:
|
|
1694
2272
|
"""Starts the API server.
|
|
1695
2273
|
|
|
@@ -1703,6 +2281,10 @@ def api_start(
|
|
|
1703
2281
|
if deploy is True, to allow remote access.
|
|
1704
2282
|
foreground: Whether to run the API server in the foreground (run in
|
|
1705
2283
|
the current process).
|
|
2284
|
+
metrics: Whether to export metrics of the API server.
|
|
2285
|
+
metrics_port: The port to export metrics of the API server.
|
|
2286
|
+
enable_basic_auth: Whether to enable basic authentication
|
|
2287
|
+
in the API server.
|
|
1706
2288
|
Returns:
|
|
1707
2289
|
None
|
|
1708
2290
|
"""
|
|
@@ -1721,15 +2303,15 @@ def api_start(
|
|
|
1721
2303
|
'from the config file and/or unset the '
|
|
1722
2304
|
'SKYPILOT_API_SERVER_ENDPOINT environment '
|
|
1723
2305
|
'variable.')
|
|
1724
|
-
server_common.check_server_healthy_or_start_fn(deploy, host, foreground
|
|
2306
|
+
server_common.check_server_healthy_or_start_fn(deploy, host, foreground,
|
|
2307
|
+
metrics, metrics_port,
|
|
2308
|
+
enable_basic_auth)
|
|
1725
2309
|
if foreground:
|
|
1726
2310
|
# Explain why current process exited
|
|
1727
2311
|
logger.info('API server is already running:')
|
|
1728
2312
|
api_server_url = server_common.get_server_url(host)
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server: '
|
|
1732
|
-
f'{api_server_url} {dashboard_msg}\n'
|
|
2313
|
+
logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server and dashboard: '
|
|
2314
|
+
f'{api_server_url}\n'
|
|
1733
2315
|
f'{ux_utils.INDENT_LAST_SYMBOL}'
|
|
1734
2316
|
f'View API server logs at: {constants.API_SERVER_LOGS}')
|
|
1735
2317
|
|
|
@@ -1752,16 +2334,32 @@ def api_stop() -> None:
|
|
|
1752
2334
|
f'Cannot kill the API server at {server_url} because it is not '
|
|
1753
2335
|
f'the default SkyPilot API server started locally.')
|
|
1754
2336
|
|
|
1755
|
-
|
|
1756
|
-
|
|
1757
|
-
|
|
1758
|
-
|
|
1759
|
-
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
|
|
2337
|
+
# Acquire the api server creation lock to prevent multiple processes from
|
|
2338
|
+
# stopping and starting the API server at the same time.
|
|
2339
|
+
with filelock.FileLock(
|
|
2340
|
+
os.path.expanduser(constants.API_SERVER_CREATION_LOCK_PATH)):
|
|
2341
|
+
try:
|
|
2342
|
+
records = scheduler.get_controller_process_records()
|
|
2343
|
+
if records is not None:
|
|
2344
|
+
for record in records:
|
|
2345
|
+
try:
|
|
2346
|
+
if managed_job_utils.controller_process_alive(
|
|
2347
|
+
record, quiet=False):
|
|
2348
|
+
subprocess_utils.kill_children_processes(
|
|
2349
|
+
parent_pids=[record.pid], force=True)
|
|
2350
|
+
except (psutil.NoSuchProcess, psutil.ZombieProcess):
|
|
2351
|
+
continue
|
|
2352
|
+
os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
|
|
2353
|
+
except FileNotFoundError:
|
|
2354
|
+
# its fine we will create it
|
|
2355
|
+
pass
|
|
2356
|
+
except Exception as e: # pylint: disable=broad-except
|
|
2357
|
+
# in case we get perm issues or something is messed up, just ignore
|
|
2358
|
+
# it and assume the process is dead
|
|
2359
|
+
logger.error(f'Error looking at job controller pid file: {e}')
|
|
2360
|
+
pass
|
|
2361
|
+
|
|
2362
|
+
found = _local_api_server_running(kill=True)
|
|
1765
2363
|
|
|
1766
2364
|
if found:
|
|
1767
2365
|
logger.info(f'{colorama.Fore.GREEN}SkyPilot API server stopped.'
|
|
@@ -1796,9 +2394,86 @@ def api_server_logs(follow: bool = True, tail: Optional[int] = None) -> None:
|
|
|
1796
2394
|
stream_and_get(log_path=constants.API_SERVER_LOGS, tail=tail)
|
|
1797
2395
|
|
|
1798
2396
|
|
|
2397
|
+
def _save_config_updates(endpoint: Optional[str] = None,
|
|
2398
|
+
service_account_token: Optional[str] = None) -> None:
|
|
2399
|
+
"""Save endpoint and/or service account token to config file."""
|
|
2400
|
+
config_path = pathlib.Path(
|
|
2401
|
+
skypilot_config.get_user_config_path()).expanduser()
|
|
2402
|
+
with filelock.FileLock(config_path.with_suffix('.lock')):
|
|
2403
|
+
if not config_path.exists():
|
|
2404
|
+
config_path.touch()
|
|
2405
|
+
config: Dict[str, Any] = {}
|
|
2406
|
+
else:
|
|
2407
|
+
config = skypilot_config.get_user_config()
|
|
2408
|
+
config = dict(config)
|
|
2409
|
+
|
|
2410
|
+
# Update endpoint if provided
|
|
2411
|
+
if endpoint is not None:
|
|
2412
|
+
# We should always reset the api_server config to avoid legacy
|
|
2413
|
+
# service account token.
|
|
2414
|
+
config['api_server'] = {}
|
|
2415
|
+
config['api_server']['endpoint'] = endpoint
|
|
2416
|
+
|
|
2417
|
+
# Update service account token if provided
|
|
2418
|
+
if service_account_token is not None:
|
|
2419
|
+
if 'api_server' not in config:
|
|
2420
|
+
config['api_server'] = {}
|
|
2421
|
+
config['api_server'][
|
|
2422
|
+
'service_account_token'] = service_account_token
|
|
2423
|
+
|
|
2424
|
+
yaml_utils.dump_yaml(str(config_path), config)
|
|
2425
|
+
skypilot_config.reload_config()
|
|
2426
|
+
|
|
2427
|
+
|
|
2428
|
+
def _clear_api_server_config() -> None:
|
|
2429
|
+
"""Clear endpoint and service account token from config file."""
|
|
2430
|
+
config_path = pathlib.Path(
|
|
2431
|
+
skypilot_config.get_user_config_path()).expanduser()
|
|
2432
|
+
with filelock.FileLock(config_path.with_suffix('.lock')):
|
|
2433
|
+
if not config_path.exists():
|
|
2434
|
+
return
|
|
2435
|
+
|
|
2436
|
+
config = skypilot_config.get_user_config()
|
|
2437
|
+
config = dict(config)
|
|
2438
|
+
if 'api_server' in config:
|
|
2439
|
+
# We might not have set the endpoint in the config file, so we
|
|
2440
|
+
# need to check before deleting.
|
|
2441
|
+
del config['api_server']
|
|
2442
|
+
|
|
2443
|
+
yaml_utils.dump_yaml(str(config_path), config, blank=True)
|
|
2444
|
+
skypilot_config.reload_config()
|
|
2445
|
+
|
|
2446
|
+
|
|
2447
|
+
def _validate_endpoint(endpoint: Optional[str]) -> str:
|
|
2448
|
+
"""Validate and normalize the endpoint URL."""
|
|
2449
|
+
if endpoint is None:
|
|
2450
|
+
endpoint = click.prompt('Enter your SkyPilot API server endpoint')
|
|
2451
|
+
# Check endpoint is a valid URL
|
|
2452
|
+
if (endpoint is not None and not endpoint.startswith('http://') and
|
|
2453
|
+
not endpoint.startswith('https://')):
|
|
2454
|
+
raise click.BadParameter('Endpoint must be a valid URL.')
|
|
2455
|
+
return endpoint.rstrip('/')
|
|
2456
|
+
|
|
2457
|
+
|
|
2458
|
+
def _check_endpoint_in_env_var(is_login: bool) -> None:
|
|
2459
|
+
# If the user has set the endpoint via the environment variable, we should
|
|
2460
|
+
# not do anything as we can't disambiguate between the env var and the
|
|
2461
|
+
# config file.
|
|
2462
|
+
"""Check if the endpoint is set in the environment variable."""
|
|
2463
|
+
if constants.SKY_API_SERVER_URL_ENV_VAR in os.environ:
|
|
2464
|
+
with ux_utils.print_exception_no_traceback():
|
|
2465
|
+
action = 'login to' if is_login else 'logout of'
|
|
2466
|
+
raise RuntimeError(f'Cannot {action} API server when the endpoint '
|
|
2467
|
+
'is set via the environment variable. Run unset '
|
|
2468
|
+
f'{constants.SKY_API_SERVER_URL_ENV_VAR} to '
|
|
2469
|
+
'clear the environment variable.')
|
|
2470
|
+
|
|
2471
|
+
|
|
1799
2472
|
@usage_lib.entrypoint
|
|
1800
2473
|
@annotations.client_api
|
|
1801
|
-
def api_login(endpoint: Optional[str] = None
|
|
2474
|
+
def api_login(endpoint: Optional[str] = None,
|
|
2475
|
+
relogin: bool = False,
|
|
2476
|
+
service_account_token: Optional[str] = None) -> None:
|
|
1802
2477
|
"""Logs into a SkyPilot API server.
|
|
1803
2478
|
|
|
1804
2479
|
This sets the endpoint globally, i.e., all SkyPilot CLI and SDK calls will
|
|
@@ -1810,37 +2485,262 @@ def api_login(endpoint: Optional[str] = None) -> None:
|
|
|
1810
2485
|
Args:
|
|
1811
2486
|
endpoint: The endpoint of the SkyPilot API server, e.g.,
|
|
1812
2487
|
http://1.2.3.4:46580 or https://skypilot.mydomain.com.
|
|
2488
|
+
relogin: Whether to force relogin with OAuth2 when enabled.
|
|
2489
|
+
service_account_token: Service account token for authentication.
|
|
1813
2490
|
|
|
1814
2491
|
Returns:
|
|
1815
2492
|
None
|
|
1816
2493
|
"""
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
#
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
|
|
1828
|
-
|
|
1829
|
-
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
|
|
2494
|
+
_check_endpoint_in_env_var(is_login=True)
|
|
2495
|
+
|
|
2496
|
+
# Validate and normalize endpoint
|
|
2497
|
+
endpoint = _validate_endpoint(endpoint)
|
|
2498
|
+
|
|
2499
|
+
def _show_logged_in_message(
|
|
2500
|
+
endpoint: str, dashboard_url: str, user: Optional[Dict[str, Any]],
|
|
2501
|
+
server_status: server_common.ApiServerStatus) -> None:
|
|
2502
|
+
"""Show the logged in message."""
|
|
2503
|
+
if server_status != server_common.ApiServerStatus.HEALTHY:
|
|
2504
|
+
with ux_utils.print_exception_no_traceback():
|
|
2505
|
+
raise ValueError(f'Cannot log in API server at '
|
|
2506
|
+
f'{endpoint} (status: {server_status.value})')
|
|
2507
|
+
|
|
2508
|
+
identity_info = f'\n{ux_utils.INDENT_SYMBOL}{colorama.Fore.GREEN}User: '
|
|
2509
|
+
if user:
|
|
2510
|
+
user_name = user.get('name')
|
|
2511
|
+
user_id = user.get('id')
|
|
2512
|
+
if user_name and user_id:
|
|
2513
|
+
identity_info += f'{user_name} ({user_id})'
|
|
2514
|
+
elif user_id:
|
|
2515
|
+
identity_info += user_id
|
|
1836
2516
|
else:
|
|
1837
|
-
|
|
1838
|
-
config.set_nested(('api_server', 'endpoint'), endpoint)
|
|
1839
|
-
common_utils.dump_yaml(str(config_path), dict(config))
|
|
1840
|
-
dashboard_url = server_common.get_dashboard_url(endpoint)
|
|
2517
|
+
identity_info = ''
|
|
1841
2518
|
dashboard_msg = f'Dashboard: {dashboard_url}'
|
|
1842
2519
|
click.secho(
|
|
1843
2520
|
f'Logged into SkyPilot API server at: {endpoint}'
|
|
2521
|
+
f'{identity_info}'
|
|
1844
2522
|
f'\n{ux_utils.INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
|
|
1845
2523
|
f'{dashboard_msg}',
|
|
1846
2524
|
fg='green')
|
|
2525
|
+
|
|
2526
|
+
def _set_user_hash(user_hash: Optional[str]) -> None:
|
|
2527
|
+
if user_hash is not None:
|
|
2528
|
+
if not common_utils.is_valid_user_hash(user_hash):
|
|
2529
|
+
raise ValueError(f'Invalid user hash: {user_hash}')
|
|
2530
|
+
common_utils.set_user_hash_locally(user_hash)
|
|
2531
|
+
|
|
2532
|
+
# Handle service account token authentication
|
|
2533
|
+
if service_account_token:
|
|
2534
|
+
if not service_account_token.startswith('sky_'):
|
|
2535
|
+
raise ValueError('Invalid service account token format. '
|
|
2536
|
+
'Token must start with "sky_"')
|
|
2537
|
+
|
|
2538
|
+
# Save both endpoint and token to config in a single operation
|
|
2539
|
+
_save_config_updates(endpoint=endpoint,
|
|
2540
|
+
service_account_token=service_account_token)
|
|
2541
|
+
|
|
2542
|
+
# Test the authentication by checking server health
|
|
2543
|
+
try:
|
|
2544
|
+
server_status, api_server_info = server_common.check_server_healthy(
|
|
2545
|
+
endpoint)
|
|
2546
|
+
dashboard_url = server_common.get_dashboard_url(endpoint)
|
|
2547
|
+
if api_server_info.user is not None:
|
|
2548
|
+
_set_user_hash(api_server_info.user.get('id'))
|
|
2549
|
+
_show_logged_in_message(endpoint, dashboard_url,
|
|
2550
|
+
api_server_info.user, server_status)
|
|
2551
|
+
|
|
2552
|
+
return
|
|
2553
|
+
except exceptions.ApiServerConnectionError as e:
|
|
2554
|
+
with ux_utils.print_exception_no_traceback():
|
|
2555
|
+
raise RuntimeError(
|
|
2556
|
+
f'Failed to connect to API server at {endpoint}: {e}'
|
|
2557
|
+
) from e
|
|
2558
|
+
except Exception as e: # pylint: disable=broad-except
|
|
2559
|
+
with ux_utils.print_exception_no_traceback():
|
|
2560
|
+
raise RuntimeError(
|
|
2561
|
+
f'{colorama.Fore.RED}Service account token authentication '
|
|
2562
|
+
f'failed:{colorama.Style.RESET_ALL} {e}') from None
|
|
2563
|
+
|
|
2564
|
+
# OAuth2/cookie-based authentication flow
|
|
2565
|
+
# TODO(zhwu): this SDK sets global endpoint, which may not be the best
|
|
2566
|
+
# design as a user may expect this is only effective for the current
|
|
2567
|
+
# session. We should consider using env var for specifying endpoint.
|
|
2568
|
+
|
|
2569
|
+
server_status, api_server_info = server_common.check_server_healthy(
|
|
2570
|
+
endpoint)
|
|
2571
|
+
if server_status == server_common.ApiServerStatus.NEEDS_AUTH or relogin:
|
|
2572
|
+
# We detected an auth proxy, so go through the auth proxy cookie flow.
|
|
2573
|
+
token: Optional[str] = None
|
|
2574
|
+
server: Optional[oauth_lib.HTTPServer] = None
|
|
2575
|
+
try:
|
|
2576
|
+
callback_port = common_utils.find_free_port(8000)
|
|
2577
|
+
|
|
2578
|
+
token_container: Dict[str, Optional[str]] = {'token': None}
|
|
2579
|
+
logger.debug('Starting local authentication server...')
|
|
2580
|
+
server = oauth_lib.start_local_auth_server(callback_port,
|
|
2581
|
+
token_container,
|
|
2582
|
+
endpoint)
|
|
2583
|
+
|
|
2584
|
+
token_url = (f'{endpoint}/token?local_port={callback_port}')
|
|
2585
|
+
if webbrowser.open(token_url):
|
|
2586
|
+
click.echo(f'{colorama.Fore.GREEN}A web browser has been '
|
|
2587
|
+
f'opened at {token_url}. Please continue the login '
|
|
2588
|
+
f'in the web browser.{colorama.Style.RESET_ALL}\n'
|
|
2589
|
+
f'{colorama.Style.DIM}To manually copy the token, '
|
|
2590
|
+
f'press ctrl+c.{colorama.Style.RESET_ALL}')
|
|
2591
|
+
else:
|
|
2592
|
+
raise ValueError('Failed to open browser.')
|
|
2593
|
+
|
|
2594
|
+
start_time = time.time()
|
|
2595
|
+
|
|
2596
|
+
while (token_container['token'] is None and
|
|
2597
|
+
time.time() - start_time < oauth_lib.AUTH_TIMEOUT):
|
|
2598
|
+
time.sleep(1)
|
|
2599
|
+
|
|
2600
|
+
if token_container['token'] is None:
|
|
2601
|
+
click.echo(f'{colorama.Fore.YELLOW}Authentication timed out '
|
|
2602
|
+
f'after {oauth_lib.AUTH_TIMEOUT} seconds.')
|
|
2603
|
+
else:
|
|
2604
|
+
token = token_container['token']
|
|
2605
|
+
|
|
2606
|
+
except (Exception, KeyboardInterrupt) as e: # pylint: disable=broad-except
|
|
2607
|
+
logger.debug(f'Automatic authentication failed: {e}, '
|
|
2608
|
+
'falling back to manual token entry.')
|
|
2609
|
+
if isinstance(e, KeyboardInterrupt):
|
|
2610
|
+
click.echo(f'\n{colorama.Style.DIM}Interrupted. Press ctrl+c '
|
|
2611
|
+
f'again to exit.{colorama.Style.RESET_ALL}')
|
|
2612
|
+
# Fall back to manual token entry
|
|
2613
|
+
token_url = f'{endpoint}/token'
|
|
2614
|
+
click.echo('Authentication is needed. Please visit this URL '
|
|
2615
|
+
f'to set up the token:{colorama.Style.BRIGHT}\n\n'
|
|
2616
|
+
f'{token_url}\n{colorama.Style.RESET_ALL}')
|
|
2617
|
+
token = click.prompt('Paste the token')
|
|
2618
|
+
finally:
|
|
2619
|
+
if server is not None:
|
|
2620
|
+
try:
|
|
2621
|
+
server.server_close()
|
|
2622
|
+
except Exception: # pylint: disable=broad-except
|
|
2623
|
+
pass
|
|
2624
|
+
if not token:
|
|
2625
|
+
with ux_utils.print_exception_no_traceback():
|
|
2626
|
+
raise ValueError('Authentication failed.')
|
|
2627
|
+
|
|
2628
|
+
# Parse the token.
|
|
2629
|
+
# b64decode will ignore invalid characters, but does some length and
|
|
2630
|
+
# padding checks.
|
|
2631
|
+
try:
|
|
2632
|
+
data = base64.b64decode(token)
|
|
2633
|
+
except binascii.Error as e:
|
|
2634
|
+
raise ValueError(f'Malformed token: {token}') from e
|
|
2635
|
+
logger.debug(f'Token data: {data!r}')
|
|
2636
|
+
try:
|
|
2637
|
+
json_data = json.loads(data)
|
|
2638
|
+
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
|
2639
|
+
raise ValueError(f'Malformed token data: {data!r}') from e
|
|
2640
|
+
if not isinstance(json_data, dict):
|
|
2641
|
+
raise ValueError(f'Malformed token JSON: {json_data}')
|
|
2642
|
+
|
|
2643
|
+
if json_data.get('v') == 1:
|
|
2644
|
+
user_hash = json_data.get('user')
|
|
2645
|
+
cookie_dict = json_data['cookies']
|
|
2646
|
+
elif 'v' not in json_data:
|
|
2647
|
+
user_hash = None
|
|
2648
|
+
cookie_dict = json_data
|
|
2649
|
+
else:
|
|
2650
|
+
raise ValueError(f'Unsupported token version: {json_data.get("v")}')
|
|
2651
|
+
|
|
2652
|
+
parsed_url = urlparse.urlparse(endpoint)
|
|
2653
|
+
cookie_jar = cookiejar.MozillaCookieJar()
|
|
2654
|
+
for (name, value) in cookie_dict.items():
|
|
2655
|
+
# dict keys in JSON must be strings
|
|
2656
|
+
assert isinstance(name, str)
|
|
2657
|
+
if not isinstance(value, str):
|
|
2658
|
+
raise ValueError('Malformed token - bad key/value: '
|
|
2659
|
+
f'{name}: {value}')
|
|
2660
|
+
|
|
2661
|
+
# See CookieJar._cookie_from_cookie_tuple
|
|
2662
|
+
# oauth2proxy default is Max-Age 604800
|
|
2663
|
+
expires = int(time.time()) + 604800
|
|
2664
|
+
domain = str(parsed_url.hostname)
|
|
2665
|
+
domain_initial_dot = domain.startswith('.')
|
|
2666
|
+
secure = parsed_url.scheme == 'https'
|
|
2667
|
+
if not domain_initial_dot:
|
|
2668
|
+
domain = '.' + domain
|
|
2669
|
+
|
|
2670
|
+
cookie_jar.set_cookie(
|
|
2671
|
+
cookiejar.Cookie(
|
|
2672
|
+
version=0,
|
|
2673
|
+
name=name,
|
|
2674
|
+
value=value,
|
|
2675
|
+
port=None,
|
|
2676
|
+
port_specified=False,
|
|
2677
|
+
domain=domain,
|
|
2678
|
+
domain_specified=True,
|
|
2679
|
+
domain_initial_dot=domain_initial_dot,
|
|
2680
|
+
path='',
|
|
2681
|
+
path_specified=False,
|
|
2682
|
+
secure=secure,
|
|
2683
|
+
expires=expires,
|
|
2684
|
+
discard=False,
|
|
2685
|
+
comment=None,
|
|
2686
|
+
comment_url=None,
|
|
2687
|
+
rest=dict(),
|
|
2688
|
+
))
|
|
2689
|
+
|
|
2690
|
+
# Now that the cookies are parsed, save them to the cookie jar.
|
|
2691
|
+
server_common.set_api_cookie_jar(cookie_jar)
|
|
2692
|
+
|
|
2693
|
+
# Set the user hash in the local file.
|
|
2694
|
+
# If the server already has a token for this user set it to the local
|
|
2695
|
+
# file, otherwise use the new user hash.
|
|
2696
|
+
if (api_server_info.user is not None and
|
|
2697
|
+
api_server_info.user.get('id') is not None):
|
|
2698
|
+
_set_user_hash(api_server_info.user.get('id'))
|
|
2699
|
+
else:
|
|
2700
|
+
_set_user_hash(user_hash)
|
|
2701
|
+
else:
|
|
2702
|
+
# Check if basic auth is enabled
|
|
2703
|
+
if api_server_info.basic_auth_enabled:
|
|
2704
|
+
if api_server_info.user is None:
|
|
2705
|
+
with ux_utils.print_exception_no_traceback():
|
|
2706
|
+
raise ValueError(
|
|
2707
|
+
'Basic auth is enabled but no valid user is found')
|
|
2708
|
+
|
|
2709
|
+
# Set the user hash in the local file.
|
|
2710
|
+
if api_server_info.user is not None:
|
|
2711
|
+
_set_user_hash(api_server_info.user.get('id'))
|
|
2712
|
+
|
|
2713
|
+
# Set the endpoint in the config file
|
|
2714
|
+
_save_config_updates(endpoint=endpoint)
|
|
2715
|
+
dashboard_url = server_common.get_dashboard_url(endpoint)
|
|
2716
|
+
|
|
2717
|
+
# see https://github.com/python/mypy/issues/5107 on why
|
|
2718
|
+
# typing is disabled on this line
|
|
2719
|
+
server_common.get_api_server_status.cache_clear() # type: ignore
|
|
2720
|
+
# After successful authentication, check server health again to get user
|
|
2721
|
+
# identity
|
|
2722
|
+
server_status, final_api_server_info = server_common.check_server_healthy(
|
|
2723
|
+
endpoint)
|
|
2724
|
+
_show_logged_in_message(endpoint, dashboard_url, final_api_server_info.user,
|
|
2725
|
+
server_status)
|
|
2726
|
+
|
|
2727
|
+
|
|
2728
|
+
@usage_lib.entrypoint
|
|
2729
|
+
@annotations.client_api
|
|
2730
|
+
def api_logout() -> None:
|
|
2731
|
+
"""Logout of the API server.
|
|
2732
|
+
|
|
2733
|
+
Clears all cookies and settings stored in ~/.sky/config.yaml"""
|
|
2734
|
+
_check_endpoint_in_env_var(is_login=False)
|
|
2735
|
+
|
|
2736
|
+
if server_common.is_api_server_local():
|
|
2737
|
+
with ux_utils.print_exception_no_traceback():
|
|
2738
|
+
raise RuntimeError('Local api server cannot be logged out. '
|
|
2739
|
+
'Use `sky api stop` instead.')
|
|
2740
|
+
|
|
2741
|
+
# no need to clear cookies if it doesn't exist.
|
|
2742
|
+
server_common.set_api_cookie_jar(cookiejar.MozillaCookieJar(),
|
|
2743
|
+
create_if_not_exists=False)
|
|
2744
|
+
_clear_api_server_config()
|
|
2745
|
+
logger.info(f'{colorama.Fore.GREEN}Logged out of SkyPilot API server.'
|
|
2746
|
+
f'{colorama.Style.RESET_ALL}')
|