skypilot-nightly 1.0.0.dev20250509__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of skypilot-nightly might be problematic. Click here for more details.
- sky/__init__.py +22 -6
- sky/adaptors/aws.py +25 -7
- sky/adaptors/common.py +24 -1
- sky/adaptors/coreweave.py +278 -0
- sky/adaptors/do.py +8 -2
- sky/adaptors/hyperbolic.py +8 -0
- sky/adaptors/kubernetes.py +149 -18
- sky/adaptors/nebius.py +170 -17
- sky/adaptors/primeintellect.py +1 -0
- sky/adaptors/runpod.py +68 -0
- sky/adaptors/seeweb.py +167 -0
- sky/adaptors/shadeform.py +89 -0
- sky/admin_policy.py +187 -4
- sky/authentication.py +179 -225
- sky/backends/__init__.py +4 -2
- sky/backends/backend.py +22 -9
- sky/backends/backend_utils.py +1299 -380
- sky/backends/cloud_vm_ray_backend.py +1715 -518
- sky/backends/docker_utils.py +1 -1
- sky/backends/local_docker_backend.py +11 -6
- sky/backends/wheel_utils.py +37 -9
- sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
- sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
- sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
- sky/{clouds/service_catalog → catalog}/common.py +89 -48
- sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
- sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +30 -40
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +42 -15
- sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
- sky/catalog/data_fetchers/fetch_nebius.py +335 -0
- sky/catalog/data_fetchers/fetch_runpod.py +698 -0
- sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
- sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
- sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
- sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
- sky/catalog/hyperbolic_catalog.py +136 -0
- sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
- sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
- sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
- sky/catalog/primeintellect_catalog.py +95 -0
- sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
- sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
- sky/catalog/seeweb_catalog.py +184 -0
- sky/catalog/shadeform_catalog.py +165 -0
- sky/catalog/ssh_catalog.py +167 -0
- sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
- sky/check.py +491 -203
- sky/cli.py +5 -6005
- sky/client/{cli.py → cli/command.py} +2477 -1885
- sky/client/cli/deprecation_utils.py +99 -0
- sky/client/cli/flags.py +359 -0
- sky/client/cli/table_utils.py +320 -0
- sky/client/common.py +70 -32
- sky/client/oauth.py +82 -0
- sky/client/sdk.py +1203 -297
- sky/client/sdk_async.py +833 -0
- sky/client/service_account_auth.py +47 -0
- sky/cloud_stores.py +73 -0
- sky/clouds/__init__.py +13 -0
- sky/clouds/aws.py +358 -93
- sky/clouds/azure.py +105 -83
- sky/clouds/cloud.py +127 -36
- sky/clouds/cudo.py +68 -50
- sky/clouds/do.py +66 -48
- sky/clouds/fluidstack.py +63 -44
- sky/clouds/gcp.py +339 -110
- sky/clouds/hyperbolic.py +293 -0
- sky/clouds/ibm.py +70 -49
- sky/clouds/kubernetes.py +563 -162
- sky/clouds/lambda_cloud.py +74 -54
- sky/clouds/nebius.py +206 -80
- sky/clouds/oci.py +88 -66
- sky/clouds/paperspace.py +61 -44
- sky/clouds/primeintellect.py +317 -0
- sky/clouds/runpod.py +164 -74
- sky/clouds/scp.py +89 -83
- sky/clouds/seeweb.py +466 -0
- sky/clouds/shadeform.py +400 -0
- sky/clouds/ssh.py +263 -0
- sky/clouds/utils/aws_utils.py +10 -4
- sky/clouds/utils/gcp_utils.py +87 -11
- sky/clouds/utils/oci_utils.py +38 -14
- sky/clouds/utils/scp_utils.py +177 -124
- sky/clouds/vast.py +99 -77
- sky/clouds/vsphere.py +51 -40
- sky/core.py +349 -139
- sky/dag.py +15 -0
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
- sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-74503c8e80fd253b.js +6 -0
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
- sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
- sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
- sky/dashboard/out/_next/static/chunks/3785.ad6adaa2a0fa9768.js +1 -0
- sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
- sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
- sky/dashboard/out/_next/static/chunks/4725.a830b5c9e7867c92.js +1 -0
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
- sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
- sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
- sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
- sky/dashboard/out/_next/static/chunks/6601-06114c982db410b6.js +1 -0
- sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
- sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
- sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
- sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
- sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
- sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
- sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
- sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
- sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a37d2063af475a1c.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters-d44859594e6f8064.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-6edeb7d06032adfc.js +21 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-479dde13399cf270.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-5ab3b907622cf0fe.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-c5a3eeee1c218af1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces-22b23febb3e89ce1.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
- sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -0
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -0
- sky/dashboard/out/infra.html +1 -0
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -0
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -0
- sky/dashboard/out/volumes.html +1 -0
- sky/dashboard/out/workspace/new.html +1 -0
- sky/dashboard/out/workspaces/[name].html +1 -0
- sky/dashboard/out/workspaces.html +1 -0
- sky/data/data_utils.py +137 -1
- sky/data/mounting_utils.py +269 -84
- sky/data/storage.py +1451 -1807
- sky/data/storage_utils.py +43 -57
- sky/exceptions.py +132 -2
- sky/execution.py +206 -63
- sky/global_user_state.py +2374 -586
- sky/jobs/__init__.py +5 -0
- sky/jobs/client/sdk.py +242 -65
- sky/jobs/client/sdk_async.py +143 -0
- sky/jobs/constants.py +9 -8
- sky/jobs/controller.py +839 -277
- sky/jobs/file_content_utils.py +80 -0
- sky/jobs/log_gc.py +201 -0
- sky/jobs/recovery_strategy.py +398 -152
- sky/jobs/scheduler.py +315 -189
- sky/jobs/server/core.py +829 -255
- sky/jobs/server/server.py +156 -115
- sky/jobs/server/utils.py +136 -0
- sky/jobs/state.py +2092 -701
- sky/jobs/utils.py +1242 -160
- sky/logs/__init__.py +21 -0
- sky/logs/agent.py +108 -0
- sky/logs/aws.py +243 -0
- sky/logs/gcp.py +91 -0
- sky/metrics/__init__.py +0 -0
- sky/metrics/utils.py +443 -0
- sky/models.py +78 -1
- sky/optimizer.py +164 -70
- sky/provision/__init__.py +90 -4
- sky/provision/aws/config.py +147 -26
- sky/provision/aws/instance.py +135 -50
- sky/provision/azure/instance.py +10 -5
- sky/provision/common.py +13 -1
- sky/provision/cudo/cudo_machine_type.py +1 -1
- sky/provision/cudo/cudo_utils.py +14 -8
- sky/provision/cudo/cudo_wrapper.py +72 -71
- sky/provision/cudo/instance.py +10 -6
- sky/provision/do/instance.py +10 -6
- sky/provision/do/utils.py +4 -3
- sky/provision/docker_utils.py +114 -23
- sky/provision/fluidstack/instance.py +13 -8
- sky/provision/gcp/__init__.py +1 -0
- sky/provision/gcp/config.py +301 -19
- sky/provision/gcp/constants.py +218 -0
- sky/provision/gcp/instance.py +36 -8
- sky/provision/gcp/instance_utils.py +18 -4
- sky/provision/gcp/volume_utils.py +247 -0
- sky/provision/hyperbolic/__init__.py +12 -0
- sky/provision/hyperbolic/config.py +10 -0
- sky/provision/hyperbolic/instance.py +437 -0
- sky/provision/hyperbolic/utils.py +373 -0
- sky/provision/instance_setup.py +93 -14
- sky/provision/kubernetes/__init__.py +5 -0
- sky/provision/kubernetes/config.py +9 -52
- sky/provision/kubernetes/constants.py +17 -0
- sky/provision/kubernetes/instance.py +789 -247
- sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
- sky/provision/kubernetes/network.py +27 -17
- sky/provision/kubernetes/network_utils.py +40 -43
- sky/provision/kubernetes/utils.py +1192 -531
- sky/provision/kubernetes/volume.py +282 -0
- sky/provision/lambda_cloud/instance.py +22 -16
- sky/provision/nebius/constants.py +50 -0
- sky/provision/nebius/instance.py +19 -6
- sky/provision/nebius/utils.py +196 -91
- sky/provision/oci/instance.py +10 -5
- sky/provision/paperspace/instance.py +10 -7
- sky/provision/paperspace/utils.py +1 -1
- sky/provision/primeintellect/__init__.py +10 -0
- sky/provision/primeintellect/config.py +11 -0
- sky/provision/primeintellect/instance.py +454 -0
- sky/provision/primeintellect/utils.py +398 -0
- sky/provision/provisioner.py +110 -36
- sky/provision/runpod/__init__.py +5 -0
- sky/provision/runpod/instance.py +27 -6
- sky/provision/runpod/utils.py +51 -18
- sky/provision/runpod/volume.py +180 -0
- sky/provision/scp/__init__.py +15 -0
- sky/provision/scp/config.py +93 -0
- sky/provision/scp/instance.py +531 -0
- sky/provision/seeweb/__init__.py +11 -0
- sky/provision/seeweb/config.py +13 -0
- sky/provision/seeweb/instance.py +807 -0
- sky/provision/shadeform/__init__.py +11 -0
- sky/provision/shadeform/config.py +12 -0
- sky/provision/shadeform/instance.py +351 -0
- sky/provision/shadeform/shadeform_utils.py +83 -0
- sky/provision/ssh/__init__.py +18 -0
- sky/provision/vast/instance.py +13 -8
- sky/provision/vast/utils.py +10 -7
- sky/provision/vsphere/common/vim_utils.py +1 -2
- sky/provision/vsphere/instance.py +15 -10
- sky/provision/vsphere/vsphere_utils.py +9 -19
- sky/py.typed +0 -0
- sky/resources.py +844 -118
- sky/schemas/__init__.py +0 -0
- sky/schemas/api/__init__.py +0 -0
- sky/schemas/api/responses.py +225 -0
- sky/schemas/db/README +4 -0
- sky/schemas/db/env.py +90 -0
- sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
- sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
- sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
- sky/schemas/db/global_user_state/004_is_managed.py +34 -0
- sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
- sky/schemas/db/global_user_state/006_provision_log.py +41 -0
- sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
- sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
- sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
- sky/schemas/db/script.py.mako +28 -0
- sky/schemas/db/serve_state/001_initial_schema.py +67 -0
- sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
- sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
- sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
- sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
- sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
- sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
- sky/schemas/generated/__init__.py +0 -0
- sky/schemas/generated/autostopv1_pb2.py +36 -0
- sky/schemas/generated/autostopv1_pb2.pyi +43 -0
- sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
- sky/schemas/generated/jobsv1_pb2.py +86 -0
- sky/schemas/generated/jobsv1_pb2.pyi +254 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +74 -0
- sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
- sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
- sky/schemas/generated/servev1_pb2.py +58 -0
- sky/schemas/generated/servev1_pb2.pyi +115 -0
- sky/schemas/generated/servev1_pb2_grpc.py +322 -0
- sky/serve/autoscalers.py +357 -5
- sky/serve/client/impl.py +310 -0
- sky/serve/client/sdk.py +47 -139
- sky/serve/client/sdk_async.py +130 -0
- sky/serve/constants.py +10 -8
- sky/serve/controller.py +64 -19
- sky/serve/load_balancer.py +106 -60
- sky/serve/load_balancing_policies.py +115 -1
- sky/serve/replica_managers.py +273 -162
- sky/serve/serve_rpc_utils.py +179 -0
- sky/serve/serve_state.py +554 -251
- sky/serve/serve_utils.py +733 -220
- sky/serve/server/core.py +66 -711
- sky/serve/server/impl.py +1093 -0
- sky/serve/server/server.py +21 -18
- sky/serve/service.py +133 -48
- sky/serve/service_spec.py +135 -16
- sky/serve/spot_placer.py +3 -0
- sky/server/auth/__init__.py +0 -0
- sky/server/auth/authn.py +50 -0
- sky/server/auth/loopback.py +38 -0
- sky/server/auth/oauth2_proxy.py +200 -0
- sky/server/common.py +475 -181
- sky/server/config.py +81 -23
- sky/server/constants.py +44 -6
- sky/server/daemons.py +229 -0
- sky/server/html/token_page.html +185 -0
- sky/server/metrics.py +160 -0
- sky/server/requests/executor.py +528 -138
- sky/server/requests/payloads.py +351 -17
- sky/server/requests/preconditions.py +21 -17
- sky/server/requests/process.py +112 -29
- sky/server/requests/request_names.py +120 -0
- sky/server/requests/requests.py +817 -224
- sky/server/requests/serializers/decoders.py +82 -31
- sky/server/requests/serializers/encoders.py +140 -22
- sky/server/requests/threads.py +106 -0
- sky/server/rest.py +417 -0
- sky/server/server.py +1290 -284
- sky/server/state.py +20 -0
- sky/server/stream_utils.py +345 -57
- sky/server/uvicorn.py +217 -3
- sky/server/versions.py +270 -0
- sky/setup_files/MANIFEST.in +5 -0
- sky/setup_files/alembic.ini +156 -0
- sky/setup_files/dependencies.py +136 -31
- sky/setup_files/setup.py +44 -42
- sky/sky_logging.py +102 -5
- sky/skylet/attempt_skylet.py +1 -0
- sky/skylet/autostop_lib.py +129 -8
- sky/skylet/configs.py +27 -20
- sky/skylet/constants.py +171 -19
- sky/skylet/events.py +105 -21
- sky/skylet/job_lib.py +335 -104
- sky/skylet/log_lib.py +297 -18
- sky/skylet/log_lib.pyi +44 -1
- sky/skylet/ray_patches/__init__.py +17 -3
- sky/skylet/ray_patches/autoscaler.py.diff +18 -0
- sky/skylet/ray_patches/cli.py.diff +19 -0
- sky/skylet/ray_patches/command_runner.py.diff +17 -0
- sky/skylet/ray_patches/log_monitor.py.diff +20 -0
- sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
- sky/skylet/ray_patches/updater.py.diff +18 -0
- sky/skylet/ray_patches/worker.py.diff +41 -0
- sky/skylet/services.py +564 -0
- sky/skylet/skylet.py +63 -4
- sky/skylet/subprocess_daemon.py +103 -29
- sky/skypilot_config.py +506 -99
- sky/ssh_node_pools/__init__.py +1 -0
- sky/ssh_node_pools/core.py +135 -0
- sky/ssh_node_pools/server.py +233 -0
- sky/task.py +621 -137
- sky/templates/aws-ray.yml.j2 +10 -3
- sky/templates/azure-ray.yml.j2 +1 -1
- sky/templates/do-ray.yml.j2 +1 -1
- sky/templates/gcp-ray.yml.j2 +57 -0
- sky/templates/hyperbolic-ray.yml.j2 +67 -0
- sky/templates/jobs-controller.yaml.j2 +27 -24
- sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
- sky/templates/kubernetes-ray.yml.j2 +607 -51
- sky/templates/lambda-ray.yml.j2 +1 -1
- sky/templates/nebius-ray.yml.j2 +33 -12
- sky/templates/paperspace-ray.yml.j2 +1 -1
- sky/templates/primeintellect-ray.yml.j2 +71 -0
- sky/templates/runpod-ray.yml.j2 +9 -1
- sky/templates/scp-ray.yml.j2 +3 -50
- sky/templates/seeweb-ray.yml.j2 +108 -0
- sky/templates/shadeform-ray.yml.j2 +72 -0
- sky/templates/sky-serve-controller.yaml.j2 +22 -2
- sky/templates/websocket_proxy.py +178 -18
- sky/usage/usage_lib.py +18 -11
- sky/users/__init__.py +0 -0
- sky/users/model.conf +15 -0
- sky/users/permission.py +387 -0
- sky/users/rbac.py +121 -0
- sky/users/server.py +720 -0
- sky/users/token_service.py +218 -0
- sky/utils/accelerator_registry.py +34 -5
- sky/utils/admin_policy_utils.py +84 -38
- sky/utils/annotations.py +16 -5
- sky/utils/asyncio_utils.py +78 -0
- sky/utils/auth_utils.py +153 -0
- sky/utils/benchmark_utils.py +60 -0
- sky/utils/cli_utils/status_utils.py +159 -86
- sky/utils/cluster_utils.py +31 -9
- sky/utils/command_runner.py +354 -68
- sky/utils/command_runner.pyi +93 -3
- sky/utils/common.py +35 -8
- sky/utils/common_utils.py +310 -87
- sky/utils/config_utils.py +87 -5
- sky/utils/context.py +402 -0
- sky/utils/context_utils.py +222 -0
- sky/utils/controller_utils.py +264 -89
- sky/utils/dag_utils.py +31 -12
- sky/utils/db/__init__.py +0 -0
- sky/utils/db/db_utils.py +470 -0
- sky/utils/db/migration_utils.py +133 -0
- sky/utils/directory_utils.py +12 -0
- sky/utils/env_options.py +13 -0
- sky/utils/git.py +567 -0
- sky/utils/git_clone.sh +460 -0
- sky/utils/infra_utils.py +195 -0
- sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
- sky/utils/kubernetes/config_map_utils.py +133 -0
- sky/utils/kubernetes/create_cluster.sh +13 -27
- sky/utils/kubernetes/delete_cluster.sh +10 -7
- sky/utils/kubernetes/deploy_remote_cluster.py +1299 -0
- sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
- sky/utils/kubernetes/generate_kind_config.py +6 -66
- sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
- sky/utils/kubernetes/gpu_labeler.py +5 -5
- sky/utils/kubernetes/kubernetes_deploy_utils.py +354 -47
- sky/utils/kubernetes/ssh-tunnel.sh +379 -0
- sky/utils/kubernetes/ssh_utils.py +221 -0
- sky/utils/kubernetes_enums.py +8 -15
- sky/utils/lock_events.py +94 -0
- sky/utils/locks.py +368 -0
- sky/utils/log_utils.py +300 -6
- sky/utils/perf_utils.py +22 -0
- sky/utils/resource_checker.py +298 -0
- sky/utils/resources_utils.py +249 -32
- sky/utils/rich_utils.py +213 -37
- sky/utils/schemas.py +905 -147
- sky/utils/serialize_utils.py +16 -0
- sky/utils/status_lib.py +10 -0
- sky/utils/subprocess_utils.py +38 -15
- sky/utils/tempstore.py +70 -0
- sky/utils/timeline.py +24 -52
- sky/utils/ux_utils.py +84 -15
- sky/utils/validator.py +11 -1
- sky/utils/volume.py +86 -0
- sky/utils/yaml_utils.py +111 -0
- sky/volumes/__init__.py +13 -0
- sky/volumes/client/__init__.py +0 -0
- sky/volumes/client/sdk.py +149 -0
- sky/volumes/server/__init__.py +0 -0
- sky/volumes/server/core.py +258 -0
- sky/volumes/server/server.py +122 -0
- sky/volumes/volume.py +212 -0
- sky/workspaces/__init__.py +0 -0
- sky/workspaces/core.py +655 -0
- sky/workspaces/server.py +101 -0
- sky/workspaces/utils.py +56 -0
- skypilot_nightly-1.0.0.dev20251107.dist-info/METADATA +675 -0
- skypilot_nightly-1.0.0.dev20251107.dist-info/RECORD +594 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +1 -1
- sky/benchmark/benchmark_state.py +0 -256
- sky/benchmark/benchmark_utils.py +0 -641
- sky/clouds/service_catalog/constants.py +0 -7
- sky/dashboard/out/_next/static/LksQgChY5izXjokL3LcEu/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/236-f49500b82ad5392d.js +0 -6
- sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
- sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
- sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
- sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
- sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
- sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
- sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
- sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
- sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-e15db85d0ea1fbe1.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-03f279c6741fb48b.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
- sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
- sky/jobs/dashboard/dashboard.py +0 -223
- sky/jobs/dashboard/static/favicon.ico +0 -0
- sky/jobs/dashboard/templates/index.html +0 -831
- sky/jobs/server/dashboard_utils.py +0 -69
- sky/skylet/providers/scp/__init__.py +0 -2
- sky/skylet/providers/scp/config.py +0 -149
- sky/skylet/providers/scp/node_provider.py +0 -578
- sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
- sky/utils/db_utils.py +0 -100
- sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
- sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
- skypilot_nightly-1.0.0.dev20250509.dist-info/METADATA +0 -361
- skypilot_nightly-1.0.0.dev20250509.dist-info/RECORD +0 -396
- /sky/{clouds/service_catalog → catalog}/config.py +0 -0
- /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
- /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
- /sky/dashboard/out/_next/static/{LksQgChY5izXjokL3LcEu → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
sky/client/sdk.py
CHANGED
|
@@ -10,75 +10,164 @@ Usage example:
|
|
|
10
10
|
statuses = sky.get(request_id)
|
|
11
11
|
|
|
12
12
|
"""
|
|
13
|
-
import
|
|
13
|
+
from http import cookiejar
|
|
14
14
|
import json
|
|
15
15
|
import logging
|
|
16
16
|
import os
|
|
17
|
-
import pathlib
|
|
18
17
|
import subprocess
|
|
19
18
|
import typing
|
|
20
|
-
from typing import Any, Dict, List, Optional, Tuple,
|
|
21
|
-
|
|
19
|
+
from typing import (Any, Dict, Iterator, List, Literal, Optional, Tuple,
|
|
20
|
+
TypeVar, Union)
|
|
21
|
+
from urllib import parse as urlparse
|
|
22
22
|
|
|
23
23
|
import click
|
|
24
24
|
import colorama
|
|
25
25
|
import filelock
|
|
26
26
|
|
|
27
27
|
from sky import admin_policy
|
|
28
|
-
from sky import backends
|
|
29
28
|
from sky import exceptions
|
|
30
29
|
from sky import sky_logging
|
|
31
30
|
from sky import skypilot_config
|
|
32
31
|
from sky.adaptors import common as adaptors_common
|
|
33
32
|
from sky.client import common as client_common
|
|
33
|
+
from sky.client import oauth as oauth_lib
|
|
34
|
+
from sky.jobs import scheduler
|
|
35
|
+
from sky.schemas.api import responses
|
|
34
36
|
from sky.server import common as server_common
|
|
37
|
+
from sky.server import rest
|
|
38
|
+
from sky.server import versions
|
|
35
39
|
from sky.server.requests import payloads
|
|
40
|
+
from sky.server.requests import request_names
|
|
36
41
|
from sky.server.requests import requests as requests_lib
|
|
42
|
+
from sky.skylet import autostop_lib
|
|
37
43
|
from sky.skylet import constants
|
|
38
44
|
from sky.usage import usage_lib
|
|
45
|
+
from sky.utils import admin_policy_utils
|
|
39
46
|
from sky.utils import annotations
|
|
40
47
|
from sky.utils import cluster_utils
|
|
41
48
|
from sky.utils import common
|
|
42
49
|
from sky.utils import common_utils
|
|
50
|
+
from sky.utils import context as sky_context
|
|
43
51
|
from sky.utils import dag_utils
|
|
44
52
|
from sky.utils import env_options
|
|
53
|
+
from sky.utils import infra_utils
|
|
45
54
|
from sky.utils import rich_utils
|
|
46
55
|
from sky.utils import status_lib
|
|
47
56
|
from sky.utils import subprocess_utils
|
|
48
57
|
from sky.utils import ux_utils
|
|
58
|
+
from sky.utils import yaml_utils
|
|
59
|
+
from sky.utils.kubernetes import ssh_utils
|
|
49
60
|
|
|
50
61
|
if typing.TYPE_CHECKING:
|
|
62
|
+
import base64
|
|
63
|
+
import binascii
|
|
51
64
|
import io
|
|
65
|
+
import pathlib
|
|
66
|
+
import time
|
|
67
|
+
import webbrowser
|
|
52
68
|
|
|
53
69
|
import psutil
|
|
54
70
|
import requests
|
|
55
71
|
|
|
56
72
|
import sky
|
|
73
|
+
from sky import backends
|
|
74
|
+
from sky import catalog
|
|
75
|
+
from sky import models
|
|
76
|
+
from sky.provision.kubernetes import utils as kubernetes_utils
|
|
77
|
+
from sky.skylet import job_lib
|
|
57
78
|
else:
|
|
79
|
+
# only used in api_login()
|
|
80
|
+
base64 = adaptors_common.LazyImport('base64')
|
|
81
|
+
binascii = adaptors_common.LazyImport('binascii')
|
|
82
|
+
pathlib = adaptors_common.LazyImport('pathlib')
|
|
83
|
+
time = adaptors_common.LazyImport('time')
|
|
84
|
+
# only used in dashboard() and api_login()
|
|
85
|
+
webbrowser = adaptors_common.LazyImport('webbrowser')
|
|
86
|
+
# only used in api_stop()
|
|
58
87
|
psutil = adaptors_common.LazyImport('psutil')
|
|
59
|
-
requests = adaptors_common.LazyImport('requests')
|
|
60
88
|
|
|
61
89
|
logger = sky_logging.init_logger(__name__)
|
|
62
90
|
logging.getLogger('httpx').setLevel(logging.CRITICAL)
|
|
63
91
|
|
|
92
|
+
_LINE_PROCESSED_KEY = 'line_processed'
|
|
64
93
|
|
|
65
|
-
|
|
94
|
+
T = TypeVar('T')
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def reload_config() -> None:
|
|
98
|
+
"""Reloads the client-side config."""
|
|
99
|
+
skypilot_config.safe_reload_config()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# The overloads are not comprehensive - e.g. get_result Literal[False] could be
|
|
103
|
+
# specified to return None. We can add more overloads if needed. To do that see
|
|
104
|
+
# https://github.com/python/mypy/issues/8634#issuecomment-609411104
|
|
105
|
+
@typing.overload
|
|
106
|
+
def stream_response(request_id: None,
|
|
66
107
|
response: 'requests.Response',
|
|
67
|
-
output_stream: Optional['io.TextIOBase'] = None
|
|
108
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
109
|
+
resumable: bool = False,
|
|
110
|
+
get_result: bool = True) -> None:
|
|
111
|
+
...
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@typing.overload
|
|
115
|
+
def stream_response(request_id: server_common.RequestId[T],
|
|
116
|
+
response: 'requests.Response',
|
|
117
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
118
|
+
resumable: bool = False,
|
|
119
|
+
get_result: Literal[True] = True) -> T:
|
|
120
|
+
...
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@typing.overload
|
|
124
|
+
def stream_response(request_id: server_common.RequestId[T],
|
|
125
|
+
response: 'requests.Response',
|
|
126
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
127
|
+
resumable: bool = False,
|
|
128
|
+
get_result: bool = True) -> Optional[T]:
|
|
129
|
+
...
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def stream_response(request_id: Optional[server_common.RequestId[T]],
|
|
133
|
+
response: 'requests.Response',
|
|
134
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
135
|
+
resumable: bool = False,
|
|
136
|
+
get_result: bool = True) -> Optional[T]:
|
|
68
137
|
"""Streams the response to the console.
|
|
69
138
|
|
|
70
139
|
Args:
|
|
71
|
-
request_id: The request ID.
|
|
140
|
+
request_id: The request ID of the request to stream. May be a full
|
|
141
|
+
request ID or a prefix.
|
|
142
|
+
If None, the latest request submitted to the API server is streamed.
|
|
143
|
+
Using None request_id is not recommended in multi-user environments.
|
|
72
144
|
response: The HTTP response.
|
|
73
145
|
output_stream: The output stream to write to. If None, print to the
|
|
74
146
|
console.
|
|
147
|
+
resumable: Whether the response is resumable on retry. If True, the
|
|
148
|
+
streaming will start from the previous failure point on retry.
|
|
149
|
+
get_result: Whether to get the result of the request. This will
|
|
150
|
+
typically be set to False for `--no-follow` flags as requests may
|
|
151
|
+
continue to run for long periods of time without further streaming.
|
|
75
152
|
"""
|
|
76
153
|
|
|
154
|
+
retry_context: Optional[rest.RetryContext] = None
|
|
155
|
+
if resumable:
|
|
156
|
+
retry_context = rest.get_retry_context()
|
|
77
157
|
try:
|
|
158
|
+
line_count = 0
|
|
78
159
|
for line in rich_utils.decode_rich_status(response):
|
|
79
160
|
if line is not None:
|
|
80
|
-
|
|
81
|
-
|
|
161
|
+
line_count += 1
|
|
162
|
+
if retry_context is None:
|
|
163
|
+
print(line, flush=True, end='', file=output_stream)
|
|
164
|
+
elif line_count > retry_context.line_processed:
|
|
165
|
+
print(line, flush=True, end='', file=output_stream)
|
|
166
|
+
retry_context.line_processed = line_count
|
|
167
|
+
if request_id is not None and get_result:
|
|
168
|
+
return get(request_id)
|
|
169
|
+
else:
|
|
170
|
+
return None
|
|
82
171
|
except Exception: # pylint: disable=broad-except
|
|
83
172
|
logger.debug(f'To stream request logs: sky api logs {request_id}')
|
|
84
173
|
raise
|
|
@@ -87,13 +176,18 @@ def stream_response(request_id: Optional[str],
|
|
|
87
176
|
@usage_lib.entrypoint
|
|
88
177
|
@server_common.check_server_healthy_or_start
|
|
89
178
|
@annotations.client_api
|
|
90
|
-
def check(
|
|
91
|
-
|
|
179
|
+
def check(
|
|
180
|
+
infra_list: Optional[Tuple[str, ...]],
|
|
181
|
+
verbose: bool,
|
|
182
|
+
workspace: Optional[str] = None
|
|
183
|
+
) -> server_common.RequestId[Dict[str, List[str]]]:
|
|
92
184
|
"""Checks the credentials to enable clouds.
|
|
93
185
|
|
|
94
186
|
Args:
|
|
95
|
-
|
|
187
|
+
infra_list: The infras to check; only the cloud part of each entry is used.
|
|
96
188
|
verbose: Whether to show verbose output.
|
|
189
|
+
workspace: The workspace to check. If None, all workspaces will be
|
|
190
|
+
checked.
|
|
97
191
|
|
|
98
192
|
Returns:
|
|
99
193
|
The request ID of the check request.
|
|
@@ -101,41 +195,69 @@ def check(clouds: Optional[Tuple[str]],
|
|
|
101
195
|
Request Returns:
|
|
102
196
|
None
|
|
103
197
|
"""
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
198
|
+
if infra_list is None:
|
|
199
|
+
clouds = None
|
|
200
|
+
else:
|
|
201
|
+
specified_clouds = []
|
|
202
|
+
for infra_str in infra_list:
|
|
203
|
+
infra = infra_utils.InfraInfo.from_str(infra_str)
|
|
204
|
+
if infra.cloud is None:
|
|
205
|
+
with ux_utils.print_exception_no_traceback():
|
|
206
|
+
raise ValueError(f'Invalid infra to check: {infra_str}')
|
|
207
|
+
if infra.region is not None or infra.zone is not None:
|
|
208
|
+
region_zone = infra_str.partition('/')[-1]
|
|
209
|
+
logger.warning(f'Infra {infra_str} is specified, but `check` '
|
|
210
|
+
f'only supports checking {infra.cloud}, '
|
|
211
|
+
f'ignoring {region_zone}')
|
|
212
|
+
specified_clouds.append(infra.cloud)
|
|
213
|
+
clouds = tuple(specified_clouds)
|
|
214
|
+
body = payloads.CheckBody(clouds=clouds,
|
|
215
|
+
verbose=verbose,
|
|
216
|
+
workspace=workspace)
|
|
217
|
+
response = server_common.make_authenticated_request(
|
|
218
|
+
'POST', '/check', json=json.loads(body.model_dump_json()))
|
|
108
219
|
return server_common.get_request_id(response)
|
|
109
220
|
|
|
110
221
|
|
|
111
222
|
@usage_lib.entrypoint
|
|
112
223
|
@server_common.check_server_healthy_or_start
|
|
113
224
|
@annotations.client_api
|
|
114
|
-
def enabled_clouds(
|
|
225
|
+
def enabled_clouds(workspace: Optional[str] = None,
|
|
226
|
+
expand: bool = False) -> server_common.RequestId[List[str]]:
|
|
115
227
|
"""Gets the enabled clouds.
|
|
116
228
|
|
|
229
|
+
Args:
|
|
230
|
+
workspace: The workspace to get the enabled clouds for. If None, the
|
|
231
|
+
active workspace will be used.
|
|
232
|
+
expand: Whether to expand Kubernetes and SSH to list of resource pools.
|
|
233
|
+
|
|
117
234
|
Returns:
|
|
118
235
|
The request ID of the enabled clouds request.
|
|
119
236
|
|
|
120
237
|
Request Returns:
|
|
121
238
|
A list of enabled clouds in string format.
|
|
122
239
|
"""
|
|
123
|
-
|
|
124
|
-
|
|
240
|
+
if workspace is None:
|
|
241
|
+
workspace = skypilot_config.get_active_workspace()
|
|
242
|
+
response = server_common.make_authenticated_request(
|
|
243
|
+
'GET', f'/enabled_clouds?workspace={workspace}&expand={expand}')
|
|
125
244
|
return server_common.get_request_id(response)
|
|
126
245
|
|
|
127
246
|
|
|
128
247
|
@usage_lib.entrypoint
|
|
129
248
|
@server_common.check_server_healthy_or_start
|
|
130
249
|
@annotations.client_api
|
|
131
|
-
def list_accelerators(
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
250
|
+
def list_accelerators(
|
|
251
|
+
gpus_only: bool = True,
|
|
252
|
+
name_filter: Optional[str] = None,
|
|
253
|
+
region_filter: Optional[str] = None,
|
|
254
|
+
quantity_filter: Optional[int] = None,
|
|
255
|
+
clouds: Optional[Union[List[str], str]] = None,
|
|
256
|
+
all_regions: bool = False,
|
|
257
|
+
require_price: bool = True,
|
|
258
|
+
case_sensitive: bool = True
|
|
259
|
+
) -> server_common.RequestId[Dict[str,
|
|
260
|
+
List['catalog.common.InstanceTypeInfo']]]:
|
|
139
261
|
"""Lists the names of all accelerators offered by Sky.
|
|
140
262
|
|
|
141
263
|
This will include all accelerators offered by Sky, including those
|
|
@@ -169,10 +291,8 @@ def list_accelerators(gpus_only: bool = True,
|
|
|
169
291
|
require_price=require_price,
|
|
170
292
|
case_sensitive=case_sensitive,
|
|
171
293
|
)
|
|
172
|
-
response =
|
|
173
|
-
|
|
174
|
-
json=json.loads(body.model_dump_json()),
|
|
175
|
-
cookies=server_common.get_api_cookie_jar())
|
|
294
|
+
response = server_common.make_authenticated_request(
|
|
295
|
+
'POST', '/list_accelerators', json=json.loads(body.model_dump_json()))
|
|
176
296
|
return server_common.get_request_id(response)
|
|
177
297
|
|
|
178
298
|
|
|
@@ -180,12 +300,12 @@ def list_accelerators(gpus_only: bool = True,
|
|
|
180
300
|
@server_common.check_server_healthy_or_start
|
|
181
301
|
@annotations.client_api
|
|
182
302
|
def list_accelerator_counts(
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
303
|
+
gpus_only: bool = True,
|
|
304
|
+
name_filter: Optional[str] = None,
|
|
305
|
+
region_filter: Optional[str] = None,
|
|
306
|
+
quantity_filter: Optional[int] = None,
|
|
307
|
+
clouds: Optional[Union[List[str], str]] = None
|
|
308
|
+
) -> server_common.RequestId[Dict[str, List[float]]]:
|
|
189
309
|
"""Lists all accelerators offered by Sky and available counts.
|
|
190
310
|
|
|
191
311
|
Args:
|
|
@@ -203,17 +323,17 @@ def list_accelerator_counts(
|
|
|
203
323
|
accelerator names mapped to a list of available counts. See usage
|
|
204
324
|
in cli.py.
|
|
205
325
|
"""
|
|
206
|
-
body = payloads.
|
|
326
|
+
body = payloads.ListAcceleratorCountsBody(
|
|
207
327
|
gpus_only=gpus_only,
|
|
208
328
|
name_filter=name_filter,
|
|
209
329
|
region_filter=region_filter,
|
|
210
330
|
quantity_filter=quantity_filter,
|
|
211
331
|
clouds=clouds,
|
|
212
332
|
)
|
|
213
|
-
response =
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
333
|
+
response = server_common.make_authenticated_request(
|
|
334
|
+
'POST',
|
|
335
|
+
'/list_accelerator_counts',
|
|
336
|
+
json=json.loads(body.model_dump_json()))
|
|
217
337
|
return server_common.get_request_id(response)
|
|
218
338
|
|
|
219
339
|
|
|
@@ -224,7 +344,7 @@ def optimize(
|
|
|
224
344
|
dag: 'sky.Dag',
|
|
225
345
|
minimize: common.OptimizeTarget = common.OptimizeTarget.COST,
|
|
226
346
|
admin_policy_request_options: Optional[admin_policy.RequestOptions] = None
|
|
227
|
-
) -> server_common.RequestId:
|
|
347
|
+
) -> server_common.RequestId['sky.Dag']:
|
|
228
348
|
"""Finds the best execution plan for the given DAG.
|
|
229
349
|
|
|
230
350
|
Args:
|
|
@@ -250,9 +370,14 @@ def optimize(
|
|
|
250
370
|
body = payloads.OptimizeBody(dag=dag_str,
|
|
251
371
|
minimize=minimize,
|
|
252
372
|
request_options=admin_policy_request_options)
|
|
253
|
-
response =
|
|
254
|
-
|
|
255
|
-
|
|
373
|
+
response = server_common.make_authenticated_request(
|
|
374
|
+
'POST', '/optimize', json=json.loads(body.model_dump_json()))
|
|
375
|
+
return server_common.get_request_id(response)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def workspaces() -> server_common.RequestId[Dict[str, Any]]:
|
|
379
|
+
"""Gets the workspaces."""
|
|
380
|
+
response = server_common.make_authenticated_request('GET', '/workspaces')
|
|
256
381
|
return server_common.get_request_id(response)
|
|
257
382
|
|
|
258
383
|
|
|
@@ -279,16 +404,22 @@ def validate(
|
|
|
279
404
|
validation. This is only required when an admin policy is in use,
|
|
280
405
|
see: https://docs.skypilot.co/en/latest/cloud-setup/policy.html
|
|
281
406
|
"""
|
|
407
|
+
remote_api_version = versions.get_remote_api_version()
|
|
408
|
+
# TODO(kevin): remove this in v0.13.0
|
|
409
|
+
omit_user_specified_yaml = (remote_api_version is None or
|
|
410
|
+
remote_api_version < 15)
|
|
282
411
|
for task in dag.tasks:
|
|
412
|
+
if omit_user_specified_yaml:
|
|
413
|
+
# pylint: disable=protected-access
|
|
414
|
+
task._user_specified_yaml = None
|
|
283
415
|
task.expand_and_validate_workdir()
|
|
284
416
|
if not workdir_only:
|
|
285
417
|
task.expand_and_validate_file_mounts()
|
|
286
418
|
dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
|
|
287
419
|
body = payloads.ValidateBody(dag=dag_str,
|
|
288
420
|
request_options=admin_policy_request_options)
|
|
289
|
-
response =
|
|
290
|
-
|
|
291
|
-
cookies=server_common.get_api_cookie_jar())
|
|
421
|
+
response = server_common.make_authenticated_request(
|
|
422
|
+
'POST', '/validate', json=json.loads(body.model_dump_json()))
|
|
292
423
|
if response.status_code == 400:
|
|
293
424
|
with ux_utils.print_exception_no_traceback():
|
|
294
425
|
raise exceptions.deserialize_exception(
|
|
@@ -298,10 +429,11 @@ def validate(
|
|
|
298
429
|
@usage_lib.entrypoint
|
|
299
430
|
@server_common.check_server_healthy_or_start
|
|
300
431
|
@annotations.client_api
|
|
301
|
-
def dashboard() -> None:
|
|
432
|
+
def dashboard(starting_page: Optional[str] = None) -> None:
|
|
302
433
|
"""Starts the dashboard for SkyPilot."""
|
|
303
434
|
api_server_url = server_common.get_server_url()
|
|
304
|
-
url = server_common.get_dashboard_url(api_server_url
|
|
435
|
+
url = server_common.get_dashboard_url(api_server_url,
|
|
436
|
+
starting_page=starting_page)
|
|
305
437
|
logger.info(f'Opening dashboard in browser: {url}')
|
|
306
438
|
webbrowser.open(url)
|
|
307
439
|
|
|
@@ -309,14 +441,16 @@ def dashboard() -> None:
|
|
|
309
441
|
@usage_lib.entrypoint
|
|
310
442
|
@server_common.check_server_healthy_or_start
|
|
311
443
|
@annotations.client_api
|
|
444
|
+
@sky_context.contextual
|
|
312
445
|
def launch(
|
|
313
446
|
task: Union['sky.Task', 'sky.Dag'],
|
|
314
447
|
cluster_name: Optional[str] = None,
|
|
315
448
|
retry_until_up: bool = False,
|
|
316
449
|
idle_minutes_to_autostop: Optional[int] = None,
|
|
450
|
+
wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
|
|
317
451
|
dryrun: bool = False,
|
|
318
452
|
down: bool = False, # pylint: disable=redefined-outer-name
|
|
319
|
-
backend: Optional[backends.Backend] = None,
|
|
453
|
+
backend: Optional['backends.Backend'] = None,
|
|
320
454
|
optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
|
|
321
455
|
no_setup: bool = False,
|
|
322
456
|
clone_disk_from: Optional[str] = None,
|
|
@@ -327,7 +461,8 @@ def launch(
|
|
|
327
461
|
_is_launched_by_jobs_controller: bool = False,
|
|
328
462
|
_is_launched_by_sky_serve_controller: bool = False,
|
|
329
463
|
_disable_controller_check: bool = False,
|
|
330
|
-
) -> server_common.RequestId
|
|
464
|
+
) -> server_common.RequestId[Tuple[Optional[int],
|
|
465
|
+
Optional['backends.ResourceHandle']]]:
|
|
331
466
|
"""Launches a cluster or task.
|
|
332
467
|
|
|
333
468
|
The task's setup and run commands are executed under the task's workdir
|
|
@@ -344,7 +479,7 @@ def launch(
|
|
|
344
479
|
import sky
|
|
345
480
|
task = sky.Task(run='echo hello SkyPilot')
|
|
346
481
|
task.set_resources(
|
|
347
|
-
sky.Resources(
|
|
482
|
+
sky.Resources(infra='aws', accelerators='V100:4'))
|
|
348
483
|
sky.launch(task, cluster_name='my-cluster')
|
|
349
484
|
|
|
350
485
|
|
|
@@ -355,18 +490,31 @@ def launch(
|
|
|
355
490
|
retry_until_up: whether to retry launching the cluster until it is
|
|
356
491
|
up.
|
|
357
492
|
idle_minutes_to_autostop: automatically stop the cluster after this
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
493
|
+
many minutes of idleness, i.e., no running or pending jobs in the
|
|
494
|
+
cluster's job queue. Idleness gets reset whenever setting-up/
|
|
495
|
+
running/pending jobs are found in the job queue. Setting this
|
|
496
|
+
flag is equivalent to running
|
|
497
|
+
``sky.launch(...)`` and then
|
|
498
|
+
``sky.autostop(idle_minutes=<minutes>)``. If set, the autostop
|
|
499
|
+
config specified in the task's resources will be overridden by
|
|
500
|
+
this parameter.
|
|
501
|
+
wait_for: determines the condition for resetting the idleness timer.
|
|
502
|
+
This option works in conjunction with ``idle_minutes_to_autostop``.
|
|
503
|
+
Choices:
|
|
504
|
+
|
|
505
|
+
1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
|
|
506
|
+
connections to finish.
|
|
507
|
+
2. "jobs" - Only wait for in-progress jobs.
|
|
508
|
+
3. "none" - Wait for nothing; autostop right after
|
|
509
|
+
``idle_minutes_to_autostop``.
|
|
364
510
|
dryrun: if True, do not actually launch the cluster.
|
|
365
511
|
down: Tear down the cluster after all jobs finish (successfully or
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
512
|
+
abnormally). If --idle-minutes-to-autostop is also set, the
|
|
513
|
+
cluster will be torn down after the specified idle time.
|
|
514
|
+
Note that if errors occur during provisioning/data syncing/setting
|
|
515
|
+
up, the cluster will not be torn down for debugging purposes. If
|
|
516
|
+
set, the autostop config specified in the task's resources will be
|
|
517
|
+
overridden by this parameter.
|
|
370
518
|
backend: backend to use. If None, use the default backend
|
|
371
519
|
(CloudVMRayBackend).
|
|
372
520
|
optimize_target: target to optimize for. Choices: OptimizeTarget.COST,
|
|
@@ -422,35 +570,115 @@ def launch(
|
|
|
422
570
|
raise NotImplementedError('clone_disk_from is not implemented yet. '
|
|
423
571
|
'Please contact the SkyPilot team if you '
|
|
424
572
|
'need this feature at slack.skypilot.co.')
|
|
573
|
+
|
|
574
|
+
remote_api_version = versions.get_remote_api_version()
|
|
575
|
+
if wait_for is not None and (remote_api_version is None or
|
|
576
|
+
remote_api_version < 13):
|
|
577
|
+
logger.warning('wait_for is not supported in your API server. '
|
|
578
|
+
'Please upgrade to a newer API server to use it.')
|
|
579
|
+
|
|
425
580
|
dag = dag_utils.convert_entrypoint_to_dag(task)
|
|
581
|
+
# Override the autostop config from command line flags to task YAML.
|
|
582
|
+
for task in dag.tasks:
|
|
583
|
+
for resource in task.resources:
|
|
584
|
+
if remote_api_version is None or remote_api_version < 13:
|
|
585
|
+
# An older server would not recognize the wait_for field
|
|
586
|
+
# in the schema, so we need to omit it.
|
|
587
|
+
resource.override_autostop_config(
|
|
588
|
+
down=down, idle_minutes=idle_minutes_to_autostop)
|
|
589
|
+
else:
|
|
590
|
+
resource.override_autostop_config(
|
|
591
|
+
down=down,
|
|
592
|
+
idle_minutes=idle_minutes_to_autostop,
|
|
593
|
+
wait_for=wait_for)
|
|
594
|
+
if resource.autostop_config is not None:
|
|
595
|
+
# For backward-compatibility, get the final autostop config for
|
|
596
|
+
# admin policy.
|
|
597
|
+
# TODO(aylei): remove this after 0.12.0
|
|
598
|
+
down = resource.autostop_config.down
|
|
599
|
+
idle_minutes_to_autostop = resource.autostop_config.idle_minutes
|
|
600
|
+
|
|
426
601
|
request_options = admin_policy.RequestOptions(
|
|
427
602
|
cluster_name=cluster_name,
|
|
428
603
|
idle_minutes_to_autostop=idle_minutes_to_autostop,
|
|
429
604
|
down=down,
|
|
430
605
|
dryrun=dryrun)
|
|
606
|
+
with admin_policy_utils.apply_and_use_config_in_current_request(
|
|
607
|
+
dag,
|
|
608
|
+
request_name=request_names.AdminPolicyRequestName.CLUSTER_LAUNCH,
|
|
609
|
+
request_options=request_options,
|
|
610
|
+
at_client_side=True) as dag:
|
|
611
|
+
return _launch(
|
|
612
|
+
dag,
|
|
613
|
+
cluster_name,
|
|
614
|
+
request_options,
|
|
615
|
+
retry_until_up,
|
|
616
|
+
idle_minutes_to_autostop,
|
|
617
|
+
dryrun,
|
|
618
|
+
down,
|
|
619
|
+
backend,
|
|
620
|
+
optimize_target,
|
|
621
|
+
no_setup,
|
|
622
|
+
clone_disk_from,
|
|
623
|
+
fast,
|
|
624
|
+
_need_confirmation,
|
|
625
|
+
_is_launched_by_jobs_controller,
|
|
626
|
+
_is_launched_by_sky_serve_controller,
|
|
627
|
+
_disable_controller_check,
|
|
628
|
+
)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
def _launch(
|
|
632
|
+
dag: 'sky.Dag',
|
|
633
|
+
cluster_name: str,
|
|
634
|
+
request_options: admin_policy.RequestOptions,
|
|
635
|
+
retry_until_up: bool = False,
|
|
636
|
+
idle_minutes_to_autostop: Optional[int] = None,
|
|
637
|
+
dryrun: bool = False,
|
|
638
|
+
down: bool = False, # pylint: disable=redefined-outer-name
|
|
639
|
+
backend: Optional['backends.Backend'] = None,
|
|
640
|
+
optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
|
|
641
|
+
no_setup: bool = False,
|
|
642
|
+
clone_disk_from: Optional[str] = None,
|
|
643
|
+
fast: bool = False,
|
|
644
|
+
# Internal only:
|
|
645
|
+
# pylint: disable=invalid-name
|
|
646
|
+
_need_confirmation: bool = False,
|
|
647
|
+
_is_launched_by_jobs_controller: bool = False,
|
|
648
|
+
_is_launched_by_sky_serve_controller: bool = False,
|
|
649
|
+
_disable_controller_check: bool = False,
|
|
650
|
+
) -> server_common.RequestId[Tuple[Optional[int],
|
|
651
|
+
Optional['backends.ResourceHandle']]]:
|
|
652
|
+
"""Auxiliary function for launch(), refer to launch() for details."""
|
|
653
|
+
|
|
431
654
|
validate(dag, admin_policy_request_options=request_options)
|
|
655
|
+
# The flags have been applied to the task YAML and the backward
|
|
656
|
+
# compatibility of admin policy has been handled. We should no longer use
|
|
657
|
+
# these flags.
|
|
658
|
+
del down, idle_minutes_to_autostop
|
|
432
659
|
|
|
433
660
|
confirm_shown = False
|
|
434
661
|
if _need_confirmation:
|
|
435
662
|
cluster_status = None
|
|
436
663
|
# TODO(SKY-998): we should reduce RTTs before launching the cluster.
|
|
437
|
-
|
|
438
|
-
clusters = get(
|
|
664
|
+
status_request_id = status([cluster_name], all_users=True)
|
|
665
|
+
clusters = get(status_request_id)
|
|
439
666
|
cluster_user_hash = common_utils.get_user_hash()
|
|
440
667
|
cluster_user_hash_str = ''
|
|
441
|
-
|
|
668
|
+
current_user = common_utils.get_current_user_name()
|
|
669
|
+
cluster_user_name = current_user
|
|
442
670
|
if not clusters:
|
|
443
671
|
# Show the optimize log before the prompt if the cluster does not
|
|
444
672
|
# exist.
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
stream_and_get(
|
|
673
|
+
optimize_request_id = optimize(
|
|
674
|
+
dag, admin_policy_request_options=request_options)
|
|
675
|
+
stream_and_get(optimize_request_id)
|
|
448
676
|
else:
|
|
449
677
|
cluster_record = clusters[0]
|
|
450
678
|
cluster_status = cluster_record['status']
|
|
451
679
|
cluster_user_hash = cluster_record['user_hash']
|
|
452
680
|
cluster_user_name = cluster_record['user_name']
|
|
453
|
-
if cluster_user_name ==
|
|
681
|
+
if cluster_user_name == current_user:
|
|
454
682
|
# Only show the hash if the username is the same as the local
|
|
455
683
|
# username, to avoid confusion.
|
|
456
684
|
cluster_user_hash_str = f' (hash: {cluster_user_hash})'
|
|
@@ -492,9 +720,7 @@ def launch(
|
|
|
492
720
|
task=dag_str,
|
|
493
721
|
cluster_name=cluster_name,
|
|
494
722
|
retry_until_up=retry_until_up,
|
|
495
|
-
idle_minutes_to_autostop=idle_minutes_to_autostop,
|
|
496
723
|
dryrun=dryrun,
|
|
497
|
-
down=down,
|
|
498
724
|
backend=backend.NAME if backend else None,
|
|
499
725
|
optimize_target=optimize_target,
|
|
500
726
|
no_setup=no_setup,
|
|
@@ -507,12 +733,8 @@ def launch(
|
|
|
507
733
|
_is_launched_by_sky_serve_controller),
|
|
508
734
|
disable_controller_check=_disable_controller_check,
|
|
509
735
|
)
|
|
510
|
-
response =
|
|
511
|
-
|
|
512
|
-
json=json.loads(body.model_dump_json()),
|
|
513
|
-
timeout=5,
|
|
514
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
515
|
-
)
|
|
736
|
+
response = server_common.make_authenticated_request(
|
|
737
|
+
'POST', '/launch', json=json.loads(body.model_dump_json()), timeout=5)
|
|
516
738
|
return server_common.get_request_id(response)
|
|
517
739
|
|
|
518
740
|
|
|
@@ -524,8 +746,9 @@ def exec( # pylint: disable=redefined-builtin
|
|
|
524
746
|
cluster_name: Optional[str] = None,
|
|
525
747
|
dryrun: bool = False,
|
|
526
748
|
down: bool = False, # pylint: disable=redefined-outer-name
|
|
527
|
-
backend: Optional[backends.Backend] = None,
|
|
528
|
-
) -> server_common.RequestId
|
|
749
|
+
backend: Optional['backends.Backend'] = None,
|
|
750
|
+
) -> server_common.RequestId[Tuple[Optional[int],
|
|
751
|
+
Optional['backends.ResourceHandle']]]:
|
|
529
752
|
"""Executes a task on an existing cluster.
|
|
530
753
|
|
|
531
754
|
This function performs two actions:
|
|
@@ -591,23 +814,49 @@ def exec( # pylint: disable=redefined-builtin
|
|
|
591
814
|
backend=backend.NAME if backend else None,
|
|
592
815
|
)
|
|
593
816
|
|
|
594
|
-
response =
|
|
595
|
-
|
|
596
|
-
json=json.loads(body.model_dump_json()),
|
|
597
|
-
timeout=5,
|
|
598
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
599
|
-
)
|
|
817
|
+
response = server_common.make_authenticated_request(
|
|
818
|
+
'POST', '/exec', json=json.loads(body.model_dump_json()), timeout=5)
|
|
600
819
|
return server_common.get_request_id(response)
|
|
601
820
|
|
|
602
821
|
|
|
603
|
-
@
|
|
604
|
-
|
|
605
|
-
|
|
822
|
+
@typing.overload
|
|
823
|
+
def tail_logs(
|
|
824
|
+
cluster_name: str,
|
|
825
|
+
job_id: Optional[int],
|
|
826
|
+
follow: bool,
|
|
827
|
+
tail: int = 0,
|
|
828
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
829
|
+
*, # keyword only separator
|
|
830
|
+
preload_content: Literal[True] = True) -> int:
|
|
831
|
+
...
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
@typing.overload
|
|
606
835
|
def tail_logs(cluster_name: str,
|
|
607
836
|
job_id: Optional[int],
|
|
608
837
|
follow: bool,
|
|
609
838
|
tail: int = 0,
|
|
610
|
-
output_stream:
|
|
839
|
+
output_stream: None = None,
|
|
840
|
+
*,
|
|
841
|
+
preload_content: Literal[False]) -> Iterator[Optional[str]]:
|
|
842
|
+
...
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
# TODO(aylei): retrying a logs request produces duplicated log entries.
|
|
846
|
+
# We should fix this.
|
|
847
|
+
@usage_lib.entrypoint
|
|
848
|
+
@server_common.check_server_healthy_or_start
|
|
849
|
+
@annotations.client_api
|
|
850
|
+
@rest.retry_transient_errors()
|
|
851
|
+
def tail_logs(
|
|
852
|
+
cluster_name: str,
|
|
853
|
+
job_id: Optional[int],
|
|
854
|
+
follow: bool,
|
|
855
|
+
tail: int = 0,
|
|
856
|
+
output_stream: Optional['io.TextIOBase'] = None,
|
|
857
|
+
*, # keyword only separator
|
|
858
|
+
preload_content: bool = True
|
|
859
|
+
) -> Union[int, Iterator[Optional[str]]]:
|
|
611
860
|
"""Tails the logs of a job.
|
|
612
861
|
|
|
613
862
|
Args:
|
|
@@ -617,12 +866,21 @@ def tail_logs(cluster_name: str,
|
|
|
617
866
|
immediately.
|
|
618
867
|
tail: if > 0, tail the last N lines of the logs.
|
|
619
868
|
output_stream: the stream to write the logs to. If None, print to the
|
|
620
|
-
console.
|
|
869
|
+
console. Cannot be used with preload_content=False.
|
|
870
|
+
preload_content: if False, returns an Iterator[str | None] containing
|
|
871
|
+
the logs without the function blocking on the retrieval of the entire
|
|
872
|
+
log. Iterator returns None when the log has been completely
|
|
873
|
+
streamed. Default True. Cannot be used with output_stream.
|
|
621
874
|
|
|
622
875
|
Returns:
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
876
|
+
If preload_content is True:
|
|
877
|
+
Exit code based on success or failure of the job. 0 if success,
|
|
878
|
+
100 if the job failed. See exceptions.JobExitCode for possible exit
|
|
879
|
+
codes.
|
|
880
|
+
If preload_content is False:
|
|
881
|
+
Iterator[str | None] containing the logs without the function
|
|
882
|
+
blocking on the retrieval of the entire log. Iterator returns None
|
|
883
|
+
when the log has been completely streamed.
|
|
626
884
|
|
|
627
885
|
Request Raises:
|
|
628
886
|
ValueError: if arguments are invalid or the cluster is not supported.
|
|
@@ -635,21 +893,110 @@ def tail_logs(cluster_name: str,
|
|
|
635
893
|
sky.exceptions.CloudUserIdentityError: if we fail to get the current
|
|
636
894
|
user identity.
|
|
637
895
|
"""
|
|
896
|
+
if output_stream is not None and not preload_content:
|
|
897
|
+
raise ValueError(
|
|
898
|
+
'output_stream cannot be specified when preload_content is False')
|
|
899
|
+
|
|
638
900
|
body = payloads.ClusterJobBody(
|
|
639
901
|
cluster_name=cluster_name,
|
|
640
902
|
job_id=job_id,
|
|
641
903
|
follow=follow,
|
|
642
904
|
tail=tail,
|
|
643
905
|
)
|
|
644
|
-
response =
|
|
645
|
-
|
|
906
|
+
response = server_common.make_authenticated_request(
|
|
907
|
+
'POST',
|
|
908
|
+
'/logs',
|
|
646
909
|
json=json.loads(body.model_dump_json()),
|
|
647
910
|
stream=True,
|
|
648
911
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
649
|
-
None)
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
912
|
+
None))
|
|
913
|
+
request_id: server_common.RequestId[int] = server_common.get_request_id(
|
|
914
|
+
response)
|
|
915
|
+
if preload_content:
|
|
916
|
+
# Log request is idempotent when tail is 0, thus can resume previous
|
|
917
|
+
# streaming point on retry.
|
|
918
|
+
return stream_response(request_id=request_id,
|
|
919
|
+
response=response,
|
|
920
|
+
output_stream=output_stream,
|
|
921
|
+
resumable=(tail == 0))
|
|
922
|
+
else:
|
|
923
|
+
return rich_utils.decode_rich_status(response)
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
@usage_lib.entrypoint
|
|
927
|
+
@server_common.check_server_healthy_or_start
|
|
928
|
+
@versions.minimal_api_version(17)
|
|
929
|
+
@annotations.client_api
|
|
930
|
+
@rest.retry_transient_errors()
|
|
931
|
+
def tail_provision_logs(cluster_name: str,
|
|
932
|
+
worker: Optional[int] = None,
|
|
933
|
+
follow: bool = True,
|
|
934
|
+
tail: int = 0,
|
|
935
|
+
output_stream: Optional['io.TextIOBase'] = None) -> int:
|
|
936
|
+
"""Tails the provisioning logs (provision.log) for a cluster.
|
|
937
|
+
|
|
938
|
+
Args:
|
|
939
|
+
cluster_name: name of the cluster.
|
|
940
|
+
worker: worker id in multi-node cluster.
|
|
941
|
+
If None, stream the logs of the head node.
|
|
942
|
+
follow: follow the logs.
|
|
943
|
+
tail: lines from end to tail.
|
|
944
|
+
output_stream: optional stream to write logs.
|
|
945
|
+
Returns:
|
|
946
|
+
Exit code 0 on streaming success; raises on HTTP error.
|
|
947
|
+
"""
|
|
948
|
+
body = payloads.ProvisionLogsBody(cluster_name=cluster_name)
|
|
949
|
+
|
|
950
|
+
if worker is not None:
|
|
951
|
+
remote_api_version = versions.get_remote_api_version()
|
|
952
|
+
if remote_api_version is not None and remote_api_version >= 21:
|
|
953
|
+
if worker < 1:
|
|
954
|
+
raise ValueError('Worker must be a positive integer.')
|
|
955
|
+
body.worker = worker
|
|
956
|
+
else:
|
|
957
|
+
raise exceptions.APINotSupportedError(
|
|
958
|
+
'Worker node provision logs are not supported in your API '
|
|
959
|
+
'server. Please upgrade to a newer API server to use it.')
|
|
960
|
+
params = {
|
|
961
|
+
'follow': str(follow).lower(),
|
|
962
|
+
'tail': tail,
|
|
963
|
+
}
|
|
964
|
+
|
|
965
|
+
response = server_common.make_authenticated_request(
|
|
966
|
+
'POST',
|
|
967
|
+
'/provision_logs',
|
|
968
|
+
json=json.loads(body.model_dump_json()),
|
|
969
|
+
params=params,
|
|
970
|
+
stream=True,
|
|
971
|
+
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
972
|
+
None))
|
|
973
|
+
# Check for HTTP errors before streaming the response
|
|
974
|
+
if response.status_code != 200:
|
|
975
|
+
with ux_utils.print_exception_no_traceback():
|
|
976
|
+
raise exceptions.CommandError(response.status_code,
|
|
977
|
+
'tail_provision_logs',
|
|
978
|
+
'Failed to stream provision logs',
|
|
979
|
+
response.text)
|
|
980
|
+
|
|
981
|
+
# Log request is idempotent when tail is 0, thus can resume previous
|
|
982
|
+
# streaming point on retry.
|
|
983
|
+
# request_id=None here because /provision_logs does not create an async
|
|
984
|
+
# request. Instead, it streams a plain file from the server. This does NOT
|
|
985
|
+
# violate the stream_response doc warning about None in multi-user
|
|
986
|
+
# environments: we are not asking stream_response to select "the latest
|
|
987
|
+
# request". We already have the HTTP response to stream; request_id=None
|
|
988
|
+
# merely disables the follow-up GET. It is also necessary for --no-follow
|
|
989
|
+
# to return cleanly after printing the tailed lines. If we provided a
|
|
990
|
+
# non-None request_id here, the get(request_id) in stream_response()
|
|
991
|
+
# would fail since /provision_logs does not create a request record.
|
|
992
|
+
# By virtue of this, we set get_result to False to block get() from
|
|
993
|
+
# running.
|
|
994
|
+
stream_response(request_id=None,
|
|
995
|
+
response=response,
|
|
996
|
+
output_stream=output_stream,
|
|
997
|
+
resumable=(tail == 0),
|
|
998
|
+
get_result=False)
|
|
999
|
+
return 0
|
|
653
1000
|
|
|
654
1001
|
|
|
655
1002
|
@usage_lib.entrypoint
|
|
@@ -683,11 +1030,11 @@ def download_logs(cluster_name: str,
|
|
|
683
1030
|
cluster_name=cluster_name,
|
|
684
1031
|
job_ids=job_ids,
|
|
685
1032
|
)
|
|
686
|
-
response =
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
1033
|
+
response = server_common.make_authenticated_request(
|
|
1034
|
+
'POST', '/download_logs', json=json.loads(body.model_dump_json()))
|
|
1035
|
+
request_id: server_common.RequestId[Dict[
|
|
1036
|
+
str, str]] = server_common.get_request_id(response)
|
|
1037
|
+
job_id_remote_path_dict = stream_and_get(request_id)
|
|
691
1038
|
remote2local_path_dict = client_common.download_logs_from_api_server(
|
|
692
1039
|
job_id_remote_path_dict.values())
|
|
693
1040
|
return {
|
|
@@ -702,10 +1049,11 @@ def download_logs(cluster_name: str,
|
|
|
702
1049
|
def start(
|
|
703
1050
|
cluster_name: str,
|
|
704
1051
|
idle_minutes_to_autostop: Optional[int] = None,
|
|
1052
|
+
wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
|
|
705
1053
|
retry_until_up: bool = False,
|
|
706
1054
|
down: bool = False, # pylint: disable=redefined-outer-name
|
|
707
1055
|
force: bool = False,
|
|
708
|
-
) -> server_common.RequestId:
|
|
1056
|
+
) -> server_common.RequestId['backends.CloudVmRayResourceHandle']:
|
|
709
1057
|
"""Restart a cluster.
|
|
710
1058
|
|
|
711
1059
|
If a cluster is previously stopped (status is STOPPED) or failed in
|
|
@@ -728,6 +1076,15 @@ def start(
|
|
|
728
1076
|
flag is equivalent to running ``sky.launch()`` and then
|
|
729
1077
|
``sky.autostop(idle_minutes=<minutes>)``. If not set, the
|
|
730
1078
|
cluster will not be autostopped.
|
|
1079
|
+
wait_for: determines the condition for resetting the idleness timer.
|
|
1080
|
+
This option works in conjunction with ``idle_minutes_to_autostop``.
|
|
1081
|
+
Choices:
|
|
1082
|
+
|
|
1083
|
+
1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
|
|
1084
|
+
connections to finish.
|
|
1085
|
+
2. "jobs" - Only wait for in-progress jobs.
|
|
1086
|
+
3. "none" - Wait for nothing; autostop right after
|
|
1087
|
+
``idle_minutes_to_autostop``.
|
|
731
1088
|
retry_until_up: whether to retry launching the cluster until it is
|
|
732
1089
|
up.
|
|
733
1090
|
down: Autodown the cluster: tear down the cluster after specified
|
|
@@ -756,26 +1113,30 @@ def start(
|
|
|
756
1113
|
sky.exceptions.ClusterOwnerIdentitiesMismatchError: if the cluster to
|
|
757
1114
|
restart was launched by a different user.
|
|
758
1115
|
"""
|
|
1116
|
+
remote_api_version = versions.get_remote_api_version()
|
|
1117
|
+
if wait_for is not None and (remote_api_version is None or
|
|
1118
|
+
remote_api_version < 13):
|
|
1119
|
+
logger.warning('wait_for is not supported in your API server. '
|
|
1120
|
+
'Please upgrade to a newer API server to use it.')
|
|
1121
|
+
|
|
759
1122
|
body = payloads.StartBody(
|
|
760
1123
|
cluster_name=cluster_name,
|
|
761
1124
|
idle_minutes_to_autostop=idle_minutes_to_autostop,
|
|
1125
|
+
wait_for=wait_for,
|
|
762
1126
|
retry_until_up=retry_until_up,
|
|
763
1127
|
down=down,
|
|
764
1128
|
force=force,
|
|
765
1129
|
)
|
|
766
|
-
response =
|
|
767
|
-
|
|
768
|
-
json=json.loads(body.model_dump_json()),
|
|
769
|
-
timeout=5,
|
|
770
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
771
|
-
)
|
|
1130
|
+
response = server_common.make_authenticated_request(
|
|
1131
|
+
'POST', '/start', json=json.loads(body.model_dump_json()), timeout=5)
|
|
772
1132
|
return server_common.get_request_id(response)
|
|
773
1133
|
|
|
774
1134
|
|
|
775
1135
|
@usage_lib.entrypoint
|
|
776
1136
|
@server_common.check_server_healthy_or_start
|
|
777
1137
|
@annotations.client_api
|
|
778
|
-
def down(cluster_name: str,
|
|
1138
|
+
def down(cluster_name: str,
|
|
1139
|
+
purge: bool = False) -> server_common.RequestId[None]:
|
|
779
1140
|
"""Tears down a cluster.
|
|
780
1141
|
|
|
781
1142
|
Tearing down a cluster will delete all associated resources (all billing
|
|
@@ -809,19 +1170,16 @@ def down(cluster_name: str, purge: bool = False) -> server_common.RequestId:
|
|
|
809
1170
|
cluster_name=cluster_name,
|
|
810
1171
|
purge=purge,
|
|
811
1172
|
)
|
|
812
|
-
response =
|
|
813
|
-
|
|
814
|
-
json=json.loads(body.model_dump_json()),
|
|
815
|
-
timeout=5,
|
|
816
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
817
|
-
)
|
|
1173
|
+
response = server_common.make_authenticated_request(
|
|
1174
|
+
'POST', '/down', json=json.loads(body.model_dump_json()), timeout=5)
|
|
818
1175
|
return server_common.get_request_id(response)
|
|
819
1176
|
|
|
820
1177
|
|
|
821
1178
|
@usage_lib.entrypoint
|
|
822
1179
|
@server_common.check_server_healthy_or_start
|
|
823
1180
|
@annotations.client_api
|
|
824
|
-
def stop(cluster_name: str,
|
|
1181
|
+
def stop(cluster_name: str,
|
|
1182
|
+
purge: bool = False) -> server_common.RequestId[None]:
|
|
825
1183
|
"""Stops a cluster.
|
|
826
1184
|
|
|
827
1185
|
Data on attached disks is not lost when a cluster is stopped. Billing for
|
|
@@ -858,12 +1216,8 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
|
|
|
858
1216
|
cluster_name=cluster_name,
|
|
859
1217
|
purge=purge,
|
|
860
1218
|
)
|
|
861
|
-
response =
|
|
862
|
-
|
|
863
|
-
json=json.loads(body.model_dump_json()),
|
|
864
|
-
timeout=5,
|
|
865
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
866
|
-
)
|
|
1219
|
+
response = server_common.make_authenticated_request(
|
|
1220
|
+
'POST', '/stop', json=json.loads(body.model_dump_json()), timeout=5)
|
|
867
1221
|
return server_common.get_request_id(response)
|
|
868
1222
|
|
|
869
1223
|
|
|
@@ -871,10 +1225,11 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
|
|
|
871
1225
|
@server_common.check_server_healthy_or_start
|
|
872
1226
|
@annotations.client_api
|
|
873
1227
|
def autostop(
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
1228
|
+
cluster_name: str,
|
|
1229
|
+
idle_minutes: int,
|
|
1230
|
+
wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
|
|
1231
|
+
down: bool = False, # pylint: disable=redefined-outer-name
|
|
1232
|
+
) -> server_common.RequestId[None]:
|
|
878
1233
|
"""Schedules an autostop/autodown for a cluster.
|
|
879
1234
|
|
|
880
1235
|
Autostop/autodown will automatically stop or teardown a cluster when it
|
|
@@ -904,6 +1259,14 @@ def autostop(
|
|
|
904
1259
|
idle_minutes: the number of minutes of idleness (no pending/running
|
|
905
1260
|
jobs) after which the cluster will be stopped automatically. Setting
|
|
906
1261
|
to a negative number cancels any autostop/autodown setting.
|
|
1262
|
+
wait_for: determines the condition for resetting the idleness timer.
|
|
1263
|
+
This option works in conjunction with ``idle_minutes``.
|
|
1264
|
+
Choices:
|
|
1265
|
+
|
|
1266
|
+
1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
|
|
1267
|
+
connections to finish.
|
|
1268
|
+
2. "jobs" - Only wait for in-progress jobs.
|
|
1269
|
+
3. "none" - Wait for nothing; autostop right after ``idle_minutes``.
|
|
907
1270
|
down: if true, use autodown (tear down the cluster; non-restartable),
|
|
908
1271
|
rather than autostop (restartable).
|
|
909
1272
|
|
|
@@ -923,26 +1286,31 @@ def autostop(
|
|
|
923
1286
|
sky.exceptions.CloudUserIdentityError: if we fail to get the current
|
|
924
1287
|
user identity.
|
|
925
1288
|
"""
|
|
1289
|
+
remote_api_version = versions.get_remote_api_version()
|
|
1290
|
+
if wait_for is not None and (remote_api_version is None or
|
|
1291
|
+
remote_api_version < 13):
|
|
1292
|
+
logger.warning('wait_for is not supported in your API server. '
|
|
1293
|
+
'Please upgrade to a newer API server to use it.')
|
|
1294
|
+
|
|
926
1295
|
body = payloads.AutostopBody(
|
|
927
1296
|
cluster_name=cluster_name,
|
|
928
1297
|
idle_minutes=idle_minutes,
|
|
1298
|
+
wait_for=wait_for,
|
|
929
1299
|
down=down,
|
|
930
1300
|
)
|
|
931
|
-
response =
|
|
932
|
-
|
|
933
|
-
json=json.loads(body.model_dump_json()),
|
|
934
|
-
timeout=5,
|
|
935
|
-
cookies=server_common.get_api_cookie_jar(),
|
|
936
|
-
)
|
|
1301
|
+
response = server_common.make_authenticated_request(
|
|
1302
|
+
'POST', '/autostop', json=json.loads(body.model_dump_json()), timeout=5)
|
|
937
1303
|
return server_common.get_request_id(response)
|
|
938
1304
|
|
|
939
1305
|
|
|
940
1306
|
@usage_lib.entrypoint
|
|
941
1307
|
@server_common.check_server_healthy_or_start
|
|
942
1308
|
@annotations.client_api
|
|
943
|
-
def queue(
|
|
944
|
-
|
|
945
|
-
|
|
1309
|
+
def queue(
|
|
1310
|
+
cluster_name: str,
|
|
1311
|
+
skip_finished: bool = False,
|
|
1312
|
+
all_users: bool = False
|
|
1313
|
+
) -> server_common.RequestId[List[responses.ClusterJobRecord]]:
|
|
946
1314
|
"""Gets the job queue of a cluster.
|
|
947
1315
|
|
|
948
1316
|
Args:
|
|
@@ -955,8 +1323,8 @@ def queue(cluster_name: str,
|
|
|
955
1323
|
The request ID of the queue request.
|
|
956
1324
|
|
|
957
1325
|
Request Returns:
|
|
958
|
-
job_records (List[
|
|
959
|
-
queue.
|
|
1326
|
+
job_records (List[responses.ClusterJobRecord]): A list of job records
|
|
1327
|
+
for each job in the queue.
|
|
960
1328
|
|
|
961
1329
|
.. code-block:: python
|
|
962
1330
|
|
|
@@ -991,17 +1359,19 @@ def queue(cluster_name: str,
|
|
|
991
1359
|
skip_finished=skip_finished,
|
|
992
1360
|
all_users=all_users,
|
|
993
1361
|
)
|
|
994
|
-
response =
|
|
995
|
-
|
|
996
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1362
|
+
response = server_common.make_authenticated_request(
|
|
1363
|
+
'POST', '/queue', json=json.loads(body.model_dump_json()))
|
|
997
1364
|
return server_common.get_request_id(response)
|
|
998
1365
|
|
|
999
1366
|
|
|
1000
1367
|
@usage_lib.entrypoint
|
|
1001
1368
|
@server_common.check_server_healthy_or_start
|
|
1002
1369
|
@annotations.client_api
|
|
1003
|
-
def job_status(
|
|
1004
|
-
|
|
1370
|
+
def job_status(
|
|
1371
|
+
cluster_name: str,
|
|
1372
|
+
job_ids: Optional[List[int]] = None
|
|
1373
|
+
) -> server_common.RequestId[Dict[Optional[int],
|
|
1374
|
+
Optional['job_lib.JobStatus']]]:
|
|
1005
1375
|
"""Gets the status of jobs on a cluster.
|
|
1006
1376
|
|
|
1007
1377
|
Args:
|
|
@@ -1033,9 +1403,8 @@ def job_status(cluster_name: str,
|
|
|
1033
1403
|
cluster_name=cluster_name,
|
|
1034
1404
|
job_ids=job_ids,
|
|
1035
1405
|
)
|
|
1036
|
-
response =
|
|
1037
|
-
|
|
1038
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1406
|
+
response = server_common.make_authenticated_request(
|
|
1407
|
+
'POST', '/job_status', json=json.loads(body.model_dump_json()))
|
|
1039
1408
|
return server_common.get_request_id(response)
|
|
1040
1409
|
|
|
1041
1410
|
|
|
@@ -1049,7 +1418,7 @@ def cancel(
|
|
|
1049
1418
|
job_ids: Optional[List[int]] = None,
|
|
1050
1419
|
# pylint: disable=invalid-name
|
|
1051
1420
|
_try_cancel_if_cluster_is_init: bool = False
|
|
1052
|
-
) -> server_common.RequestId:
|
|
1421
|
+
) -> server_common.RequestId[None]:
|
|
1053
1422
|
"""Cancels jobs on a cluster.
|
|
1054
1423
|
|
|
1055
1424
|
Args:
|
|
@@ -1087,9 +1456,8 @@ def cancel(
|
|
|
1087
1456
|
job_ids=job_ids,
|
|
1088
1457
|
try_cancel_if_cluster_is_init=_try_cancel_if_cluster_is_init,
|
|
1089
1458
|
)
|
|
1090
|
-
response =
|
|
1091
|
-
|
|
1092
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1459
|
+
response = server_common.make_authenticated_request(
|
|
1460
|
+
'POST', '/cancel', json=json.loads(body.model_dump_json()))
|
|
1093
1461
|
return server_common.get_request_id(response)
|
|
1094
1462
|
|
|
1095
1463
|
|
|
@@ -1100,7 +1468,10 @@ def status(
|
|
|
1100
1468
|
cluster_names: Optional[List[str]] = None,
|
|
1101
1469
|
refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE,
|
|
1102
1470
|
all_users: bool = False,
|
|
1103
|
-
|
|
1471
|
+
*,
|
|
1472
|
+
_include_credentials: bool = False,
|
|
1473
|
+
_summary_response: bool = False,
|
|
1474
|
+
) -> server_common.RequestId[List[responses.StatusResponse]]:
|
|
1104
1475
|
"""Gets cluster statuses.
|
|
1105
1476
|
|
|
1106
1477
|
If cluster_names is given, return those clusters. Otherwise, return all
|
|
@@ -1148,6 +1519,8 @@ def status(
|
|
|
1148
1519
|
provider(s).
|
|
1149
1520
|
all_users: whether to include all users' clusters. By default, only
|
|
1150
1521
|
the current user's clusters are included.
|
|
1522
|
+
_include_credentials: (internal) whether to include cluster ssh
|
|
1523
|
+
credentials in the response (default: False).
|
|
1151
1524
|
|
|
1152
1525
|
Returns:
|
|
1153
1526
|
The request ID of the status request.
|
|
@@ -1182,10 +1555,11 @@ def status(
|
|
|
1182
1555
|
cluster_names=cluster_names,
|
|
1183
1556
|
refresh=refresh,
|
|
1184
1557
|
all_users=all_users,
|
|
1558
|
+
include_credentials=_include_credentials,
|
|
1559
|
+
summary_response=_summary_response,
|
|
1185
1560
|
)
|
|
1186
|
-
response =
|
|
1187
|
-
|
|
1188
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1561
|
+
response = server_common.make_authenticated_request(
|
|
1562
|
+
'POST', '/status', json=json.loads(body.model_dump_json()))
|
|
1189
1563
|
return server_common.get_request_id(response)
|
|
1190
1564
|
|
|
1191
1565
|
|
|
@@ -1193,10 +1567,19 @@ def status(
|
|
|
1193
1567
|
@server_common.check_server_healthy_or_start
|
|
1194
1568
|
@annotations.client_api
|
|
1195
1569
|
def endpoints(
|
|
1196
|
-
|
|
1197
|
-
|
|
1570
|
+
cluster: str,
|
|
1571
|
+
port: Optional[Union[int, str]] = None
|
|
1572
|
+
) -> server_common.RequestId[Dict[int, str]]:
|
|
1198
1573
|
"""Gets the endpoint for a given cluster and port number (endpoint).
|
|
1199
1574
|
|
|
1575
|
+
Example:
|
|
1576
|
+
.. code-block:: python
|
|
1577
|
+
|
|
1578
|
+
import sky
|
|
1579
|
+
request_id = sky.endpoints('test-cluster')
|
|
1580
|
+
sky.get(request_id)
|
|
1581
|
+
|
|
1582
|
+
|
|
1200
1583
|
Args:
|
|
1201
1584
|
cluster: The name of the cluster.
|
|
1202
1585
|
port: The port number to get the endpoint for. If None, endpoints
|
|
@@ -1206,8 +1589,9 @@ def endpoints(
|
|
|
1206
1589
|
The request ID of the endpoints request.
|
|
1207
1590
|
|
|
1208
1591
|
Request Returns:
|
|
1209
|
-
A dictionary of port numbers to endpoints.
|
|
1210
|
-
|
|
1592
|
+
A dictionary of port numbers to endpoints.
|
|
1593
|
+
If port is None, the dictionary contains all
|
|
1594
|
+
ports:endpoints exposed on the cluster.
|
|
1211
1595
|
|
|
1212
1596
|
Request Raises:
|
|
1213
1597
|
ValueError: if the cluster is not UP or the endpoint is not exposed.
|
|
@@ -1218,16 +1602,17 @@ def endpoints(
|
|
|
1218
1602
|
cluster=cluster,
|
|
1219
1603
|
port=port,
|
|
1220
1604
|
)
|
|
1221
|
-
response =
|
|
1222
|
-
|
|
1223
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1605
|
+
response = server_common.make_authenticated_request(
|
|
1606
|
+
'POST', '/endpoints', json=json.loads(body.model_dump_json()))
|
|
1224
1607
|
return server_common.get_request_id(response)
|
|
1225
1608
|
|
|
1226
1609
|
|
|
1227
1610
|
@usage_lib.entrypoint
|
|
1228
1611
|
@server_common.check_server_healthy_or_start
|
|
1229
1612
|
@annotations.client_api
|
|
1230
|
-
def cost_report(
|
|
1613
|
+
def cost_report(
|
|
1614
|
+
days: Optional[int] = None
|
|
1615
|
+
) -> server_common.RequestId[List[Dict[str, Any]]]: # pylint: disable=redefined-builtin
|
|
1231
1616
|
"""Gets all cluster cost reports, including those that have been downed.
|
|
1232
1617
|
|
|
1233
1618
|
The estimated cost column indicates price for the cluster based on the type
|
|
@@ -1237,6 +1622,10 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
|
|
|
1237
1622
|
cache of the cluster status, and may not be accurate for the cluster with
|
|
1238
1623
|
autostop/use_spot set or terminated/stopped on the cloud console.
|
|
1239
1624
|
|
|
1625
|
+
Args:
|
|
1626
|
+
days: The number of days to get the cost report for. If not provided,
|
|
1627
|
+
the default is 30 days.
|
|
1628
|
+
|
|
1240
1629
|
Returns:
|
|
1241
1630
|
The request ID of the cost report request.
|
|
1242
1631
|
|
|
@@ -1258,8 +1647,9 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
|
|
|
1258
1647
|
'total_cost': (float) cost given resources and usage intervals,
|
|
1259
1648
|
}
|
|
1260
1649
|
"""
|
|
1261
|
-
|
|
1262
|
-
|
|
1650
|
+
body = payloads.CostReportBody(days=days)
|
|
1651
|
+
response = server_common.make_authenticated_request(
|
|
1652
|
+
'POST', '/cost_report', json=json.loads(body.model_dump_json()))
|
|
1263
1653
|
return server_common.get_request_id(response)
|
|
1264
1654
|
|
|
1265
1655
|
|
|
@@ -1267,36 +1657,24 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
|
|
|
1267
1657
|
@usage_lib.entrypoint
|
|
1268
1658
|
@server_common.check_server_healthy_or_start
|
|
1269
1659
|
@annotations.client_api
|
|
1270
|
-
def storage_ls() -> server_common.RequestId:
|
|
1660
|
+
def storage_ls() -> server_common.RequestId[List[responses.StorageRecord]]:
|
|
1271
1661
|
"""Gets the storages.
|
|
1272
1662
|
|
|
1273
1663
|
Returns:
|
|
1274
1664
|
The request ID of the storage list request.
|
|
1275
1665
|
|
|
1276
1666
|
Request Returns:
|
|
1277
|
-
storage_records (List[
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
.. code-block:: python
|
|
1281
|
-
|
|
1282
|
-
{
|
|
1283
|
-
'name': (str) storage name,
|
|
1284
|
-
'launched_at': (int) timestamp of creation,
|
|
1285
|
-
'store': (List[sky.StoreType]) storage type,
|
|
1286
|
-
'last_use': (int) timestamp of last use,
|
|
1287
|
-
'status': (sky.StorageStatus) storage status,
|
|
1288
|
-
}
|
|
1289
|
-
]
|
|
1667
|
+
storage_records (List[responses.StorageRecord]):
|
|
1668
|
+
A list of storage records.
|
|
1290
1669
|
"""
|
|
1291
|
-
response =
|
|
1292
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1670
|
+
response = server_common.make_authenticated_request('GET', '/storage/ls')
|
|
1293
1671
|
return server_common.get_request_id(response)
|
|
1294
1672
|
|
|
1295
1673
|
|
|
1296
1674
|
@usage_lib.entrypoint
|
|
1297
1675
|
@server_common.check_server_healthy_or_start
|
|
1298
1676
|
@annotations.client_api
|
|
1299
|
-
def storage_delete(name: str) -> server_common.RequestId:
|
|
1677
|
+
def storage_delete(name: str) -> server_common.RequestId[None]:
|
|
1300
1678
|
"""Deletes a storage.
|
|
1301
1679
|
|
|
1302
1680
|
Args:
|
|
@@ -1312,9 +1690,8 @@ def storage_delete(name: str) -> server_common.RequestId:
|
|
|
1312
1690
|
ValueError: If the storage does not exist.
|
|
1313
1691
|
"""
|
|
1314
1692
|
body = payloads.StorageBody(name=name)
|
|
1315
|
-
response =
|
|
1316
|
-
|
|
1317
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1693
|
+
response = server_common.make_authenticated_request(
|
|
1694
|
+
'POST', '/storage/delete', json=json.loads(body.model_dump_json()))
|
|
1318
1695
|
return server_common.get_request_id(response)
|
|
1319
1696
|
|
|
1320
1697
|
|
|
@@ -1330,7 +1707,9 @@ def local_up(gpus: bool,
|
|
|
1330
1707
|
ssh_key: Optional[str],
|
|
1331
1708
|
cleanup: bool,
|
|
1332
1709
|
context_name: Optional[str] = None,
|
|
1333
|
-
password: Optional[str] = None
|
|
1710
|
+
password: Optional[str] = None,
|
|
1711
|
+
name: Optional[str] = None,
|
|
1712
|
+
port_start: Optional[int] = None) -> server_common.RequestId[None]:
|
|
1334
1713
|
"""Launches a Kubernetes cluster on local machines.
|
|
1335
1714
|
|
|
1336
1715
|
Returns:
|
|
@@ -1341,8 +1720,8 @@ def local_up(gpus: bool,
|
|
|
1341
1720
|
# TODO: move this check to server.
|
|
1342
1721
|
if not server_common.is_api_server_local():
|
|
1343
1722
|
with ux_utils.print_exception_no_traceback():
|
|
1344
|
-
raise ValueError(
|
|
1345
|
-
|
|
1723
|
+
raise ValueError('`sky local up` is only supported when '
|
|
1724
|
+
'running SkyPilot locally.')
|
|
1346
1725
|
|
|
1347
1726
|
body = payloads.LocalUpBody(gpus=gpus,
|
|
1348
1727
|
ips=ips,
|
|
@@ -1350,27 +1729,150 @@ def local_up(gpus: bool,
|
|
|
1350
1729
|
ssh_key=ssh_key,
|
|
1351
1730
|
cleanup=cleanup,
|
|
1352
1731
|
context_name=context_name,
|
|
1353
|
-
password=password
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1732
|
+
password=password,
|
|
1733
|
+
name=name,
|
|
1734
|
+
port_start=port_start)
|
|
1735
|
+
response = server_common.make_authenticated_request(
|
|
1736
|
+
'POST', '/local_up', json=json.loads(body.model_dump_json()))
|
|
1357
1737
|
return server_common.get_request_id(response)
|
|
1358
1738
|
|
|
1359
1739
|
|
|
1360
1740
|
@usage_lib.entrypoint
|
|
1361
1741
|
@server_common.check_server_healthy_or_start
|
|
1362
1742
|
@annotations.client_api
|
|
1363
|
-
def local_down() -> server_common.RequestId:
|
|
1743
|
+
def local_down(name: Optional[str]) -> server_common.RequestId[None]:
|
|
1364
1744
|
"""Tears down the Kubernetes cluster started by local_up."""
|
|
1365
1745
|
# We do not allow local up when the API server is running remotely since it
|
|
1366
1746
|
# will modify the kubeconfig.
|
|
1367
1747
|
# TODO: move this check to remote server.
|
|
1368
1748
|
if not server_common.is_api_server_local():
|
|
1369
1749
|
with ux_utils.print_exception_no_traceback():
|
|
1370
|
-
raise ValueError('sky local down is only supported when running '
|
|
1750
|
+
raise ValueError('`sky local down` is only supported when running '
|
|
1371
1751
|
'SkyPilot locally.')
|
|
1372
|
-
|
|
1373
|
-
|
|
1752
|
+
|
|
1753
|
+
body = payloads.LocalDownBody(name=name)
|
|
1754
|
+
response = server_common.make_authenticated_request(
|
|
1755
|
+
'POST', '/local_down', json=json.loads(body.model_dump_json()))
|
|
1756
|
+
return server_common.get_request_id(response)
|
|
1757
|
+
|
|
1758
|
+
|
|
1759
|
+
def _update_remote_ssh_node_pools(file: str,
|
|
1760
|
+
infra: Optional[str] = None) -> None:
|
|
1761
|
+
"""Update the SSH node pools on the remote server.
|
|
1762
|
+
|
|
1763
|
+
This function will also upload the local SSH key to the remote server, and
|
|
1764
|
+
replace the file path to the remote SSH key file path.
|
|
1765
|
+
|
|
1766
|
+
Args:
|
|
1767
|
+
file: The path to the local SSH node pools config file.
|
|
1768
|
+
infra: The name of the cluster configuration in the local SSH node
|
|
1769
|
+
pools config file. If None, all clusters in the file are updated.
|
|
1770
|
+
"""
|
|
1771
|
+
file = os.path.expanduser(file)
|
|
1772
|
+
if not os.path.exists(file):
|
|
1773
|
+
with ux_utils.print_exception_no_traceback():
|
|
1774
|
+
raise ValueError(
|
|
1775
|
+
f'SSH Node Pool config file {file} does not exist. '
|
|
1776
|
+
'Please check if the file exists and the path is correct.')
|
|
1777
|
+
config = ssh_utils.load_ssh_targets(file)
|
|
1778
|
+
config = ssh_utils.get_cluster_config(config, infra)
|
|
1779
|
+
pools_config = {}
|
|
1780
|
+
for name, pool_config in config.items():
|
|
1781
|
+
hosts_info = ssh_utils.prepare_hosts_info(
|
|
1782
|
+
name, pool_config, upload_ssh_key_func=_upload_ssh_key_and_wait)
|
|
1783
|
+
pools_config[name] = {'hosts': hosts_info}
|
|
1784
|
+
server_common.make_authenticated_request('POST',
|
|
1785
|
+
'/ssh_node_pools',
|
|
1786
|
+
json=pools_config)
|
|
1787
|
+
|
|
1788
|
+
|
|
1789
|
+
def _upload_ssh_key_and_wait(key_name: str, key_file_path: str) -> str:
|
|
1790
|
+
"""Upload the SSH key to the remote server and wait for the key to be
|
|
1791
|
+
uploaded.
|
|
1792
|
+
|
|
1793
|
+
Args:
|
|
1794
|
+
key_name: The name of the SSH key.
|
|
1795
|
+
key_file_path: The path to the local SSH key file.
|
|
1796
|
+
|
|
1797
|
+
Returns:
|
|
1798
|
+
The path for the remote SSH key file on the API server.
|
|
1799
|
+
"""
|
|
1800
|
+
if not os.path.exists(os.path.expanduser(key_file_path)):
|
|
1801
|
+
with ux_utils.print_exception_no_traceback():
|
|
1802
|
+
raise ValueError(f'SSH key file not found: {key_file_path}')
|
|
1803
|
+
|
|
1804
|
+
with open(os.path.expanduser(key_file_path), 'rb') as key_file:
|
|
1805
|
+
response = server_common.make_authenticated_request(
|
|
1806
|
+
'POST',
|
|
1807
|
+
'/ssh_node_pools/keys',
|
|
1808
|
+
files={
|
|
1809
|
+
'key_file': (key_name, key_file, 'application/octet-stream')
|
|
1810
|
+
},
|
|
1811
|
+
data={'key_name': key_name},
|
|
1812
|
+
cookies=server_common.get_api_cookie_jar())
|
|
1813
|
+
|
|
1814
|
+
return response.json()['key_path']
|
|
1815
|
+
|
|
1816
|
+
|
|
1817
|
+
@usage_lib.entrypoint
|
|
1818
|
+
@server_common.check_server_healthy_or_start
|
|
1819
|
+
@annotations.client_api
|
|
1820
|
+
def ssh_up(infra: Optional[str] = None,
|
|
1821
|
+
file: Optional[str] = None) -> server_common.RequestId[None]:
|
|
1822
|
+
"""Deploys the SSH Node Pools defined in ~/.sky/ssh_targets.yaml.
|
|
1823
|
+
|
|
1824
|
+
Args:
|
|
1825
|
+
infra: Name of the cluster configuration in ssh_targets.yaml.
|
|
1826
|
+
If None, the first cluster in the file is used.
|
|
1827
|
+
file: Name of the ssh node pool configuration file to use. If
|
|
1828
|
+
None, the default path, ~/.sky/ssh_node_pools.yaml is used.
|
|
1829
|
+
|
|
1830
|
+
Returns:
|
|
1831
|
+
request_id: The request ID of the SSH cluster deployment request.
|
|
1832
|
+
"""
|
|
1833
|
+
if file is not None:
|
|
1834
|
+
_update_remote_ssh_node_pools(file, infra)
|
|
1835
|
+
|
|
1836
|
+
# Use SSH node pools router endpoint
|
|
1837
|
+
body = payloads.SSHUpBody(infra=infra, cleanup=False)
|
|
1838
|
+
if infra is not None:
|
|
1839
|
+
# Call the specific pool deployment endpoint
|
|
1840
|
+
response = server_common.make_authenticated_request(
|
|
1841
|
+
'POST', f'/ssh_node_pools/{infra}/deploy')
|
|
1842
|
+
else:
|
|
1843
|
+
# Call the general deployment endpoint
|
|
1844
|
+
response = server_common.make_authenticated_request(
|
|
1845
|
+
'POST',
|
|
1846
|
+
'/ssh_node_pools/deploy',
|
|
1847
|
+
json=json.loads(body.model_dump_json()))
|
|
1848
|
+
return server_common.get_request_id(response)
|
|
1849
|
+
|
|
1850
|
+
|
|
1851
|
+
@usage_lib.entrypoint
|
|
1852
|
+
@server_common.check_server_healthy_or_start
|
|
1853
|
+
@annotations.client_api
|
|
1854
|
+
def ssh_down(infra: Optional[str] = None) -> server_common.RequestId[None]:
|
|
1855
|
+
"""Tears down a Kubernetes cluster on SSH targets.
|
|
1856
|
+
|
|
1857
|
+
Args:
|
|
1858
|
+
infra: Name of the cluster configuration in ssh_targets.yaml.
|
|
1859
|
+
If None, the first cluster in the file is used.
|
|
1860
|
+
|
|
1861
|
+
Returns:
|
|
1862
|
+
request_id: The request ID of the SSH cluster teardown request.
|
|
1863
|
+
"""
|
|
1864
|
+
# Use SSH node pools router endpoint
|
|
1865
|
+
body = payloads.SSHUpBody(infra=infra, cleanup=True)
|
|
1866
|
+
if infra is not None:
|
|
1867
|
+
# Call the specific pool down endpoint
|
|
1868
|
+
response = server_common.make_authenticated_request(
|
|
1869
|
+
'POST', f'/ssh_node_pools/{infra}/down')
|
|
1870
|
+
else:
|
|
1871
|
+
# Call the general down endpoint
|
|
1872
|
+
response = server_common.make_authenticated_request(
|
|
1873
|
+
'POST',
|
|
1874
|
+
'/ssh_node_pools/down',
|
|
1875
|
+
json=json.loads(body.model_dump_json()))
|
|
1374
1876
|
return server_common.get_request_id(response)
|
|
1375
1877
|
|
|
1376
1878
|
|
|
@@ -1378,9 +1880,12 @@ def local_down() -> server_common.RequestId:
|
|
|
1378
1880
|
@server_common.check_server_healthy_or_start
|
|
1379
1881
|
@annotations.client_api
|
|
1380
1882
|
def realtime_kubernetes_gpu_availability(
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1883
|
+
context: Optional[str] = None,
|
|
1884
|
+
name_filter: Optional[str] = None,
|
|
1885
|
+
quantity_filter: Optional[int] = None,
|
|
1886
|
+
is_ssh: Optional[bool] = None
|
|
1887
|
+
) -> server_common.RequestId[List[Tuple[
|
|
1888
|
+
str, List['models.RealtimeGpuAvailability']]]]:
|
|
1384
1889
|
"""Gets the real-time Kubernetes GPU availability.
|
|
1385
1890
|
|
|
1386
1891
|
Returns:
|
|
@@ -1390,12 +1895,12 @@ def realtime_kubernetes_gpu_availability(
|
|
|
1390
1895
|
context=context,
|
|
1391
1896
|
name_filter=name_filter,
|
|
1392
1897
|
quantity_filter=quantity_filter,
|
|
1898
|
+
is_ssh=is_ssh,
|
|
1393
1899
|
)
|
|
1394
|
-
response =
|
|
1395
|
-
|
|
1396
|
-
'realtime_kubernetes_gpu_availability',
|
|
1397
|
-
json=json.loads(body.model_dump_json())
|
|
1398
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1900
|
+
response = server_common.make_authenticated_request(
|
|
1901
|
+
'POST',
|
|
1902
|
+
'/realtime_kubernetes_gpu_availability',
|
|
1903
|
+
json=json.loads(body.model_dump_json()))
|
|
1399
1904
|
return server_common.get_request_id(response)
|
|
1400
1905
|
|
|
1401
1906
|
|
|
@@ -1403,7 +1908,8 @@ def realtime_kubernetes_gpu_availability(
|
|
|
1403
1908
|
@server_common.check_server_healthy_or_start
|
|
1404
1909
|
@annotations.client_api
|
|
1405
1910
|
def kubernetes_node_info(
|
|
1406
|
-
|
|
1911
|
+
context: Optional[str] = None
|
|
1912
|
+
) -> server_common.RequestId['models.KubernetesNodesInfo']:
|
|
1407
1913
|
"""Gets the resource information for all the nodes in the cluster.
|
|
1408
1914
|
|
|
1409
1915
|
Currently only GPU resources are supported. The function returns the total
|
|
@@ -1424,17 +1930,20 @@ def kubernetes_node_info(
|
|
|
1424
1930
|
information.
|
|
1425
1931
|
"""
|
|
1426
1932
|
body = payloads.KubernetesNodeInfoRequestBody(context=context)
|
|
1427
|
-
response =
|
|
1428
|
-
|
|
1429
|
-
|
|
1430
|
-
|
|
1933
|
+
response = server_common.make_authenticated_request(
|
|
1934
|
+
'POST',
|
|
1935
|
+
'/kubernetes_node_info',
|
|
1936
|
+
json=json.loads(body.model_dump_json()))
|
|
1431
1937
|
return server_common.get_request_id(response)
|
|
1432
1938
|
|
|
1433
1939
|
|
|
1434
1940
|
@usage_lib.entrypoint
|
|
1435
1941
|
@server_common.check_server_healthy_or_start
|
|
1436
1942
|
@annotations.client_api
|
|
1437
|
-
def status_kubernetes() -> server_common.RequestId
|
|
1943
|
+
def status_kubernetes() -> server_common.RequestId[
|
|
1944
|
+
Tuple[List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
|
|
1945
|
+
List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
|
|
1946
|
+
List[responses.ManagedJobRecord], Optional[str]]]:
|
|
1438
1947
|
"""Gets all SkyPilot clusters and jobs in the Kubernetes cluster.
|
|
1439
1948
|
|
|
1440
1949
|
Managed jobs and services are also included in the clusters returned.
|
|
@@ -1455,21 +1964,24 @@ def status_kubernetes() -> server_common.RequestId:
|
|
|
1455
1964
|
dictionary job info, see jobs.queue_from_kubernetes_pod for details.
|
|
1456
1965
|
- context: Kubernetes context used to fetch the cluster information.
|
|
1457
1966
|
"""
|
|
1458
|
-
response =
|
|
1459
|
-
|
|
1460
|
-
cookies=server_common.get_api_cookie_jar())
|
|
1967
|
+
response = server_common.make_authenticated_request('GET',
|
|
1968
|
+
'/status_kubernetes')
|
|
1461
1969
|
return server_common.get_request_id(response)
|
|
1462
1970
|
|
|
1463
1971
|
|
|
1464
1972
|
# === API request APIs ===
|
|
1465
1973
|
@usage_lib.entrypoint
|
|
1466
|
-
@server_common.check_server_healthy_or_start
|
|
1467
1974
|
@annotations.client_api
|
|
1468
|
-
def get(request_id:
|
|
1975
|
+
def get(request_id: server_common.RequestId[T]) -> T:
|
|
1469
1976
|
"""Waits for and gets the result of a request.
|
|
1470
1977
|
|
|
1978
|
+
This function will not check the server health since /api/get is typically
|
|
1979
|
+
not the first API call in an SDK session and checking the server health
|
|
1980
|
+
may cause GET /api/get being sent to a restarted API server.
|
|
1981
|
+
|
|
1471
1982
|
Args:
|
|
1472
|
-
request_id: The request ID of the request to get.
|
|
1983
|
+
request_id: The request ID of the request to get. May be a full request
|
|
1984
|
+
ID or a prefix.
|
|
1473
1985
|
|
|
1474
1986
|
Returns:
|
|
1475
1987
|
The ``Request Returns`` of the specified request. See the documentation
|
|
@@ -1480,19 +1992,20 @@ def get(request_id: str) -> Any:
|
|
|
1480
1992
|
see ``Request Raises`` in the documentation of the specific requests
|
|
1481
1993
|
above.
|
|
1482
1994
|
"""
|
|
1483
|
-
response =
|
|
1484
|
-
|
|
1995
|
+
response = server_common.make_authenticated_request(
|
|
1996
|
+
'GET',
|
|
1997
|
+
f'/api/get?request_id={request_id}',
|
|
1998
|
+
retry=False,
|
|
1485
1999
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
1486
|
-
None)
|
|
1487
|
-
cookies=server_common.get_api_cookie_jar())
|
|
2000
|
+
None))
|
|
1488
2001
|
request_task = None
|
|
1489
2002
|
if response.status_code == 200:
|
|
1490
2003
|
request_task = requests_lib.Request.decode(
|
|
1491
|
-
|
|
2004
|
+
payloads.RequestPayload(**response.json()))
|
|
1492
2005
|
elif response.status_code == 500:
|
|
1493
2006
|
try:
|
|
1494
2007
|
request_task = requests_lib.Request.decode(
|
|
1495
|
-
|
|
2008
|
+
payloads.RequestPayload(**response.json().get('detail')))
|
|
1496
2009
|
logger.debug(f'Got request with error: {request_task.name}')
|
|
1497
2010
|
except Exception: # pylint: disable=broad-except
|
|
1498
2011
|
request_task = None
|
|
@@ -1518,23 +2031,45 @@ def get(request_id: str) -> Any:
|
|
|
1518
2031
|
return request_task.get_return_value()
|
|
1519
2032
|
|
|
1520
2033
|
|
|
2034
|
+
@typing.overload
|
|
2035
|
+
def stream_and_get(request_id: server_common.RequestId[T],
|
|
2036
|
+
log_path: Optional[str] = None,
|
|
2037
|
+
tail: Optional[int] = None,
|
|
2038
|
+
follow: bool = True,
|
|
2039
|
+
output_stream: Optional['io.TextIOBase'] = None) -> T:
|
|
2040
|
+
...
|
|
2041
|
+
|
|
2042
|
+
|
|
2043
|
+
@typing.overload
|
|
2044
|
+
def stream_and_get(request_id: None = None,
|
|
2045
|
+
log_path: Optional[str] = None,
|
|
2046
|
+
tail: Optional[int] = None,
|
|
2047
|
+
follow: bool = True,
|
|
2048
|
+
output_stream: Optional['io.TextIOBase'] = None) -> None:
|
|
2049
|
+
...
|
|
2050
|
+
|
|
2051
|
+
|
|
1521
2052
|
@usage_lib.entrypoint
|
|
1522
2053
|
@server_common.check_server_healthy_or_start
|
|
1523
2054
|
@annotations.client_api
|
|
2055
|
+
@rest.retry_transient_errors()
|
|
1524
2056
|
def stream_and_get(
|
|
1525
|
-
request_id: Optional[
|
|
2057
|
+
request_id: Optional[server_common.RequestId[T]] = None,
|
|
1526
2058
|
log_path: Optional[str] = None,
|
|
1527
2059
|
tail: Optional[int] = None,
|
|
1528
2060
|
follow: bool = True,
|
|
1529
2061
|
output_stream: Optional['io.TextIOBase'] = None,
|
|
1530
|
-
) ->
|
|
2062
|
+
) -> Optional[T]:
|
|
1531
2063
|
"""Streams the logs of a request or a log file and gets the final result.
|
|
1532
2064
|
|
|
1533
2065
|
This will block until the request is finished. The request id can be a
|
|
1534
2066
|
prefix of the full request id.
|
|
1535
2067
|
|
|
1536
2068
|
Args:
|
|
1537
|
-
request_id: The
|
|
2069
|
+
request_id: The request ID of the request to stream. May be a full
|
|
2070
|
+
request ID or a prefix.
|
|
2071
|
+
If None, the latest request submitted to the API server is streamed.
|
|
2072
|
+
Using None request_id is not recommended in multi-user environments.
|
|
1538
2073
|
log_path: The path to the log file to stream.
|
|
1539
2074
|
tail: The number of lines to show from the end of the logs.
|
|
1540
2075
|
If None, show all logs.
|
|
@@ -1545,6 +2080,8 @@ def stream_and_get(
|
|
|
1545
2080
|
Returns:
|
|
1546
2081
|
The ``Request Returns`` of the specified request. See the documentation
|
|
1547
2082
|
of the specific requests above for more details.
|
|
2083
|
+
If follow is False, will always return None. See note on
|
|
2084
|
+
stream_response.
|
|
1548
2085
|
|
|
1549
2086
|
Raises:
|
|
1550
2087
|
Exception: It raises the same exceptions as the specific requests,
|
|
@@ -1558,27 +2095,44 @@ def stream_and_get(
|
|
|
1558
2095
|
'follow': follow,
|
|
1559
2096
|
'format': 'console',
|
|
1560
2097
|
}
|
|
1561
|
-
response =
|
|
1562
|
-
|
|
2098
|
+
response = server_common.make_authenticated_request(
|
|
2099
|
+
'GET',
|
|
2100
|
+
'/api/stream',
|
|
1563
2101
|
params=params,
|
|
2102
|
+
retry=False,
|
|
1564
2103
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
1565
2104
|
None),
|
|
1566
|
-
stream=True
|
|
1567
|
-
cookies=server_common.get_api_cookie_jar())
|
|
2105
|
+
stream=True)
|
|
1568
2106
|
if response.status_code in [404, 400]:
|
|
1569
2107
|
detail = response.json().get('detail')
|
|
1570
2108
|
with ux_utils.print_exception_no_traceback():
|
|
1571
|
-
raise
|
|
2109
|
+
raise exceptions.ClientError(f'Failed to stream logs: {detail}')
|
|
2110
|
+
stream_request_id: Optional[server_common.RequestId[
|
|
2111
|
+
T]] = server_common.get_stream_request_id(response)
|
|
2112
|
+
if request_id is not None and stream_request_id is not None:
|
|
2113
|
+
assert request_id == stream_request_id
|
|
2114
|
+
if request_id is None:
|
|
2115
|
+
request_id = stream_request_id
|
|
1572
2116
|
elif response.status_code != 200:
|
|
2117
|
+
# TODO(syang): handle the case where the requestID is not provided
|
|
2118
|
+
# see https://github.com/skypilot-org/skypilot/issues/6549
|
|
2119
|
+
if request_id is None:
|
|
2120
|
+
return None
|
|
1573
2121
|
return get(request_id)
|
|
1574
|
-
return stream_response(request_id,
|
|
2122
|
+
return stream_response(request_id,
|
|
2123
|
+
response,
|
|
2124
|
+
output_stream,
|
|
2125
|
+
resumable=True,
|
|
2126
|
+
get_result=follow)
|
|
1575
2127
|
|
|
1576
2128
|
|
|
1577
2129
|
@usage_lib.entrypoint
|
|
1578
2130
|
@annotations.client_api
|
|
1579
|
-
def api_cancel(request_ids: Optional[Union[
|
|
2131
|
+
def api_cancel(request_ids: Optional[Union[server_common.RequestId[T],
|
|
2132
|
+
List[server_common.RequestId[T]],
|
|
2133
|
+
str, List[str]]] = None,
|
|
1580
2134
|
all_users: bool = False,
|
|
1581
|
-
silent: bool = False) -> server_common.RequestId:
|
|
2135
|
+
silent: bool = False) -> server_common.RequestId[List[str]]:
|
|
1582
2136
|
"""Aborts a request or all requests.
|
|
1583
2137
|
|
|
1584
2138
|
Args:
|
|
@@ -1618,20 +2172,35 @@ def api_cancel(request_ids: Optional[Union[str, List[str]]] = None,
|
|
|
1618
2172
|
echo(f'Cancelling {len(request_ids)} request{plural}: '
|
|
1619
2173
|
f'{request_id_str}...')
|
|
1620
2174
|
|
|
1621
|
-
response =
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
|
|
2175
|
+
response = server_common.make_authenticated_request(
|
|
2176
|
+
'POST',
|
|
2177
|
+
'/api/cancel',
|
|
2178
|
+
json=json.loads(body.model_dump_json()),
|
|
2179
|
+
timeout=5)
|
|
1625
2180
|
return server_common.get_request_id(response)
|
|
1626
2181
|
|
|
1627
2182
|
|
|
2183
|
+
def _local_api_server_running(kill: bool = False) -> bool:
    """Report whether a local API server process can be found.

    Scans the process table for the first process whose joined command line
    contains ``server_common.API_SERVER_CMD``.

    Args:
        kill: If True, ``subprocess_utils.kill_children_processes`` is called
            with ``force=True`` on the matching process before returning.

    Returns:
        True if a matching process was found, False otherwise. The scan
        stops at the first match.
    """
    for proc in psutil.process_iter(attrs=['pid', 'cmdline']):
        argv = proc.info['cmdline']
        # The command line may be empty/None for some processes; skip those.
        if not argv:
            continue
        if server_common.API_SERVER_CMD not in ' '.join(argv):
            continue
        if kill:
            subprocess_utils.kill_children_processes(parent_pids=[proc.pid],
                                                     force=True)
        return True
    return False
|
|
2193
|
+
|
|
2194
|
+
|
|
1628
2195
|
@usage_lib.entrypoint
|
|
1629
2196
|
@annotations.client_api
|
|
1630
2197
|
def api_status(
|
|
1631
|
-
request_ids: Optional[List[str]] = None,
|
|
2198
|
+
request_ids: Optional[List[Union[server_common.RequestId[T], str]]] = None,
|
|
1632
2199
|
# pylint: disable=redefined-builtin
|
|
1633
|
-
all_status: bool = False
|
|
1634
|
-
|
|
2200
|
+
all_status: bool = False,
|
|
2201
|
+
limit: Optional[int] = None,
|
|
2202
|
+
fields: Optional[List[str]] = None,
|
|
2203
|
+
) -> List[payloads.RequestPayload]:
|
|
1635
2204
|
"""Lists all requests.
|
|
1636
2205
|
|
|
1637
2206
|
Args:
|
|
@@ -1639,29 +2208,37 @@ def api_status(
|
|
|
1639
2208
|
If None, all requests are queried.
|
|
1640
2209
|
all_status: Whether to list all finished requests as well. This argument
|
|
1641
2210
|
is ignored if request_ids is not None.
|
|
2211
|
+
limit: The number of requests to show. If None, show all requests.
|
|
2212
|
+
fields: The fields to get. If None, get all fields.
|
|
1642
2213
|
|
|
1643
2214
|
Returns:
|
|
1644
2215
|
A list of request payloads.
|
|
1645
2216
|
"""
|
|
1646
|
-
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
2217
|
+
if server_common.is_api_server_local() and not _local_api_server_running():
|
|
2218
|
+
logger.info('SkyPilot API server is not running.')
|
|
2219
|
+
return []
|
|
2220
|
+
|
|
2221
|
+
body = payloads.RequestStatusBody(
|
|
2222
|
+
request_ids=request_ids,
|
|
2223
|
+
all_status=all_status,
|
|
2224
|
+
limit=limit,
|
|
2225
|
+
fields=fields,
|
|
2226
|
+
)
|
|
2227
|
+
response = server_common.make_authenticated_request(
|
|
2228
|
+
'GET',
|
|
2229
|
+
'/api/status',
|
|
1650
2230
|
params=server_common.request_body_to_params(body),
|
|
1651
2231
|
timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
|
|
1652
|
-
None)
|
|
1653
|
-
cookies=server_common.get_api_cookie_jar())
|
|
2232
|
+
None))
|
|
1654
2233
|
server_common.handle_request_error(response)
|
|
1655
|
-
return [
|
|
1656
|
-
requests_lib.RequestPayload(**request) for request in response.json()
|
|
1657
|
-
]
|
|
2234
|
+
return [payloads.RequestPayload(**request) for request in response.json()]
|
|
1658
2235
|
|
|
1659
2236
|
|
|
1660
2237
|
# === API server management APIs ===
|
|
1661
2238
|
@usage_lib.entrypoint
|
|
1662
2239
|
@server_common.check_server_healthy_or_start
|
|
1663
2240
|
@annotations.client_api
|
|
1664
|
-
def api_info() ->
|
|
2241
|
+
def api_info() -> responses.APIHealthResponse:
|
|
1665
2242
|
"""Gets the server's status, commit and version.
|
|
1666
2243
|
|
|
1667
2244
|
Returns:
|
|
@@ -1674,13 +2251,19 @@ def api_info() -> Dict[str, str]:
|
|
|
1674
2251
|
'api_version': '1',
|
|
1675
2252
|
'commit': 'abc1234567890',
|
|
1676
2253
|
'version': '1.0.0',
|
|
2254
|
+
'version_on_disk': '1.0.0',
|
|
2255
|
+
'user': {
|
|
2256
|
+
'name': 'test@example.com',
|
|
2257
|
+
'id': '12345abcd',
|
|
2258
|
+
},
|
|
1677
2259
|
}
|
|
1678
2260
|
|
|
2261
|
+
Note that user may be None if we are not using an auth proxy.
|
|
2262
|
+
|
|
1679
2263
|
"""
|
|
1680
|
-
response =
|
|
1681
|
-
cookies=server_common.get_api_cookie_jar())
|
|
2264
|
+
response = server_common.make_authenticated_request('GET', '/api/health')
|
|
1682
2265
|
response.raise_for_status()
|
|
1683
|
-
return response.json()
|
|
2266
|
+
return responses.APIHealthResponse(**response.json())
|
|
1684
2267
|
|
|
1685
2268
|
|
|
1686
2269
|
@usage_lib.entrypoint
|
|
@@ -1690,6 +2273,9 @@ def api_start(
|
|
|
1690
2273
|
deploy: bool = False,
|
|
1691
2274
|
host: str = '127.0.0.1',
|
|
1692
2275
|
foreground: bool = False,
|
|
2276
|
+
metrics: bool = False,
|
|
2277
|
+
metrics_port: Optional[int] = None,
|
|
2278
|
+
enable_basic_auth: bool = False,
|
|
1693
2279
|
) -> None:
|
|
1694
2280
|
"""Starts the API server.
|
|
1695
2281
|
|
|
@@ -1703,6 +2289,10 @@ def api_start(
|
|
|
1703
2289
|
if deploy is True, to allow remote access.
|
|
1704
2290
|
foreground: Whether to run the API server in the foreground (run in
|
|
1705
2291
|
the current process).
|
|
2292
|
+
metrics: Whether to export metrics of the API server.
|
|
2293
|
+
metrics_port: The port to export metrics of the API server.
|
|
2294
|
+
enable_basic_auth: Whether to enable basic authentication
|
|
2295
|
+
in the API server.
|
|
1706
2296
|
Returns:
|
|
1707
2297
|
None
|
|
1708
2298
|
"""
|
|
@@ -1721,15 +2311,15 @@ def api_start(
|
|
|
1721
2311
|
'from the config file and/or unset the '
|
|
1722
2312
|
'SKYPILOT_API_SERVER_ENDPOINT environment '
|
|
1723
2313
|
'variable.')
|
|
1724
|
-
server_common.check_server_healthy_or_start_fn(deploy, host, foreground
|
|
2314
|
+
server_common.check_server_healthy_or_start_fn(deploy, host, foreground,
|
|
2315
|
+
metrics, metrics_port,
|
|
2316
|
+
enable_basic_auth)
|
|
1725
2317
|
if foreground:
|
|
1726
2318
|
# Explain why current process exited
|
|
1727
2319
|
logger.info('API server is already running:')
|
|
1728
2320
|
api_server_url = server_common.get_server_url(host)
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server: '
|
|
1732
|
-
f'{api_server_url} {dashboard_msg}\n'
|
|
2321
|
+
logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server and dashboard: '
|
|
2322
|
+
f'{api_server_url}\n'
|
|
1733
2323
|
f'{ux_utils.INDENT_LAST_SYMBOL}'
|
|
1734
2324
|
f'View API server logs at: {constants.API_SERVER_LOGS}')
|
|
1735
2325
|
|
|
@@ -1752,16 +2342,30 @@ def api_stop() -> None:
|
|
|
1752
2342
|
f'Cannot kill the API server at {server_url} because it is not '
|
|
1753
2343
|
f'the default SkyPilot API server started locally.')
|
|
1754
2344
|
|
|
1755
|
-
|
|
1756
|
-
|
|
1757
|
-
|
|
1758
|
-
|
|
1759
|
-
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
|
|
2345
|
+
# Acquire the api server creation lock to prevent multiple processes from
|
|
2346
|
+
# stopping and starting the API server at the same time.
|
|
2347
|
+
with filelock.FileLock(
|
|
2348
|
+
os.path.expanduser(constants.API_SERVER_CREATION_LOCK_PATH)):
|
|
2349
|
+
try:
|
|
2350
|
+
with open(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH),
|
|
2351
|
+
'r',
|
|
2352
|
+
encoding='utf-8') as f:
|
|
2353
|
+
pids = f.read().split('\n')[:-1]
|
|
2354
|
+
for pid in pids:
|
|
2355
|
+
if subprocess_utils.is_process_alive(int(pid.strip())):
|
|
2356
|
+
subprocess_utils.kill_children_processes(
|
|
2357
|
+
parent_pids=[int(pid.strip())], force=True)
|
|
2358
|
+
os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
|
|
2359
|
+
except FileNotFoundError:
|
|
2360
|
+
# its fine we will create it
|
|
2361
|
+
pass
|
|
2362
|
+
except Exception as e: # pylint: disable=broad-except
|
|
2363
|
+
# in case we get perm issues or something is messed up, just ignore
|
|
2364
|
+
# it and assume the process is dead
|
|
2365
|
+
logger.error(f'Error looking at job controller pid file: {e}')
|
|
2366
|
+
pass
|
|
2367
|
+
|
|
2368
|
+
found = _local_api_server_running(kill=True)
|
|
1765
2369
|
|
|
1766
2370
|
if found:
|
|
1767
2371
|
logger.info(f'{colorama.Fore.GREEN}SkyPilot API server stopped.'
|
|
@@ -1796,9 +2400,86 @@ def api_server_logs(follow: bool = True, tail: Optional[int] = None) -> None:
|
|
|
1796
2400
|
stream_and_get(log_path=constants.API_SERVER_LOGS, tail=tail)
|
|
1797
2401
|
|
|
1798
2402
|
|
|
2403
|
+
def _save_config_updates(endpoint: Optional[str] = None,
                         service_account_token: Optional[str] = None) -> None:
    """Write the API server endpoint and/or token to the user config file.

    The update is done while holding a file lock that sits next to the config
    file (same path with a ``.lock`` suffix). The config is reloaded after it
    has been written.

    Args:
        endpoint: If not None, the whole ``api_server`` section is replaced
            by one containing only this endpoint, which also drops any
            previously stored service account token.
        service_account_token: If not None, stored under
            ``api_server.service_account_token``.
    """
    user_config_file = pathlib.Path(
        skypilot_config.get_user_config_path()).expanduser()
    lock_file = user_config_file.with_suffix('.lock')
    with filelock.FileLock(lock_file):
        settings: Dict[str, Any]
        if user_config_file.exists():
            settings = dict(skypilot_config.get_user_config())
        else:
            # Create an empty config file and start from a blank config.
            user_config_file.touch()
            settings = {}

        if endpoint is not None:
            # Always start from a fresh api_server section so that a legacy
            # service account token is not carried over.
            settings['api_server'] = {'endpoint': endpoint}

        if service_account_token is not None:
            section = settings.setdefault('api_server', {})
            section['service_account_token'] = service_account_token

        yaml_utils.dump_yaml(str(user_config_file), settings)
        skypilot_config.reload_config()
|
|
2432
|
+
|
|
2433
|
+
|
|
2434
|
+
def _clear_api_server_config() -> None:
    """Remove the ``api_server`` section from the user config file.

    Does nothing if the config file does not exist. Otherwise the file is
    rewritten without the ``api_server`` key (endpoint and service account
    token) while holding the ``.lock`` file next to it, and the config is
    reloaded.
    """
    user_config_file = pathlib.Path(
        skypilot_config.get_user_config_path()).expanduser()
    with filelock.FileLock(user_config_file.with_suffix('.lock')):
        if user_config_file.exists():
            remaining = dict(skypilot_config.get_user_config())
            # The section may never have been written to the file, so
            # tolerate its absence.
            remaining.pop('api_server', None)
            yaml_utils.dump_yaml(str(user_config_file), remaining, blank=True)
            skypilot_config.reload_config()
|
|
2451
|
+
|
|
2452
|
+
|
|
2453
|
+
def _validate_endpoint(endpoint: Optional[str]) -> str:
|
|
2454
|
+
"""Validate and normalize the endpoint URL."""
|
|
2455
|
+
if endpoint is None:
|
|
2456
|
+
endpoint = click.prompt('Enter your SkyPilot API server endpoint')
|
|
2457
|
+
# Check endpoint is a valid URL
|
|
2458
|
+
if (endpoint is not None and not endpoint.startswith('http://') and
|
|
2459
|
+
not endpoint.startswith('https://')):
|
|
2460
|
+
raise click.BadParameter('Endpoint must be a valid URL.')
|
|
2461
|
+
return endpoint.rstrip('/')
|
|
2462
|
+
|
|
2463
|
+
|
|
2464
|
+
def _check_endpoint_in_env_var(is_login: bool) -> None:
    """Refuse to log in or out while the endpoint is set via the environment.

    If the user has set the endpoint through the environment variable, we
    cannot disambiguate between the env var and the config file, so the
    operation is rejected.

    Args:
        is_login: True when called for login, False for logout. Only affects
            the wording of the error message.

    Raises:
        RuntimeError: If ``constants.SKY_API_SERVER_URL_ENV_VAR`` is present
            in the environment.
    """
    if constants.SKY_API_SERVER_URL_ENV_VAR not in os.environ:
        return

    action = 'login to' if is_login else 'logout of'
    with ux_utils.print_exception_no_traceback():
        raise RuntimeError(f'Cannot {action} API server when the endpoint '
                           'is set via the environment variable. Run unset '
                           f'{constants.SKY_API_SERVER_URL_ENV_VAR} to '
                           'clear the environment variable.')
|
|
2476
|
+
|
|
2477
|
+
|
|
1799
2478
|
@usage_lib.entrypoint
|
|
1800
2479
|
@annotations.client_api
|
|
1801
|
-
def api_login(endpoint: Optional[str] = None
|
|
2480
|
+
def api_login(endpoint: Optional[str] = None,
|
|
2481
|
+
relogin: bool = False,
|
|
2482
|
+
service_account_token: Optional[str] = None) -> None:
|
|
1802
2483
|
"""Logs into a SkyPilot API server.
|
|
1803
2484
|
|
|
1804
2485
|
This sets the endpoint globally, i.e., all SkyPilot CLI and SDK calls will
|
|
@@ -1810,37 +2491,262 @@ def api_login(endpoint: Optional[str] = None) -> None:
|
|
|
1810
2491
|
Args:
|
|
1811
2492
|
endpoint: The endpoint of the SkyPilot API server, e.g.,
|
|
1812
2493
|
http://1.2.3.4:46580 or https://skypilot.mydomain.com.
|
|
2494
|
+
relogin: Whether to force relogin with OAuth2 when enabled.
|
|
2495
|
+
service_account_token: Service account token for authentication.
|
|
1813
2496
|
|
|
1814
2497
|
Returns:
|
|
1815
2498
|
None
|
|
1816
2499
|
"""
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
#
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
|
|
1828
|
-
|
|
1829
|
-
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
|
|
2500
|
+
_check_endpoint_in_env_var(is_login=True)
|
|
2501
|
+
|
|
2502
|
+
# Validate and normalize endpoint
|
|
2503
|
+
endpoint = _validate_endpoint(endpoint)
|
|
2504
|
+
|
|
2505
|
+
def _show_logged_in_message(
|
|
2506
|
+
endpoint: str, dashboard_url: str, user: Optional[Dict[str, Any]],
|
|
2507
|
+
server_status: server_common.ApiServerStatus) -> None:
|
|
2508
|
+
"""Show the logged in message."""
|
|
2509
|
+
if server_status != server_common.ApiServerStatus.HEALTHY:
|
|
2510
|
+
with ux_utils.print_exception_no_traceback():
|
|
2511
|
+
raise ValueError(f'Cannot log in API server at '
|
|
2512
|
+
f'{endpoint} (status: {server_status.value})')
|
|
2513
|
+
|
|
2514
|
+
identity_info = f'\n{ux_utils.INDENT_SYMBOL}{colorama.Fore.GREEN}User: '
|
|
2515
|
+
if user:
|
|
2516
|
+
user_name = user.get('name')
|
|
2517
|
+
user_id = user.get('id')
|
|
2518
|
+
if user_name and user_id:
|
|
2519
|
+
identity_info += f'{user_name} ({user_id})'
|
|
2520
|
+
elif user_id:
|
|
2521
|
+
identity_info += user_id
|
|
1836
2522
|
else:
|
|
1837
|
-
|
|
1838
|
-
config.set_nested(('api_server', 'endpoint'), endpoint)
|
|
1839
|
-
common_utils.dump_yaml(str(config_path), dict(config))
|
|
1840
|
-
dashboard_url = server_common.get_dashboard_url(endpoint)
|
|
2523
|
+
identity_info = ''
|
|
1841
2524
|
dashboard_msg = f'Dashboard: {dashboard_url}'
|
|
1842
2525
|
click.secho(
|
|
1843
2526
|
f'Logged into SkyPilot API server at: {endpoint}'
|
|
2527
|
+
f'{identity_info}'
|
|
1844
2528
|
f'\n{ux_utils.INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
|
|
1845
2529
|
f'{dashboard_msg}',
|
|
1846
2530
|
fg='green')
|
|
2531
|
+
|
|
2532
|
+
def _set_user_hash(user_hash: Optional[str]) -> None:
|
|
2533
|
+
if user_hash is not None:
|
|
2534
|
+
if not common_utils.is_valid_user_hash(user_hash):
|
|
2535
|
+
raise ValueError(f'Invalid user hash: {user_hash}')
|
|
2536
|
+
common_utils.set_user_hash_locally(user_hash)
|
|
2537
|
+
|
|
2538
|
+
# Handle service account token authentication
|
|
2539
|
+
if service_account_token:
|
|
2540
|
+
if not service_account_token.startswith('sky_'):
|
|
2541
|
+
raise ValueError('Invalid service account token format. '
|
|
2542
|
+
'Token must start with "sky_"')
|
|
2543
|
+
|
|
2544
|
+
# Save both endpoint and token to config in a single operation
|
|
2545
|
+
_save_config_updates(endpoint=endpoint,
|
|
2546
|
+
service_account_token=service_account_token)
|
|
2547
|
+
|
|
2548
|
+
# Test the authentication by checking server health
|
|
2549
|
+
try:
|
|
2550
|
+
server_status, api_server_info = server_common.check_server_healthy(
|
|
2551
|
+
endpoint)
|
|
2552
|
+
dashboard_url = server_common.get_dashboard_url(endpoint)
|
|
2553
|
+
if api_server_info.user is not None:
|
|
2554
|
+
_set_user_hash(api_server_info.user.get('id'))
|
|
2555
|
+
_show_logged_in_message(endpoint, dashboard_url,
|
|
2556
|
+
api_server_info.user, server_status)
|
|
2557
|
+
|
|
2558
|
+
return
|
|
2559
|
+
except exceptions.ApiServerConnectionError as e:
|
|
2560
|
+
with ux_utils.print_exception_no_traceback():
|
|
2561
|
+
raise RuntimeError(
|
|
2562
|
+
f'Failed to connect to API server at {endpoint}: {e}'
|
|
2563
|
+
) from e
|
|
2564
|
+
except Exception as e: # pylint: disable=broad-except
|
|
2565
|
+
with ux_utils.print_exception_no_traceback():
|
|
2566
|
+
raise RuntimeError(
|
|
2567
|
+
f'{colorama.Fore.RED}Service account token authentication '
|
|
2568
|
+
f'failed:{colorama.Style.RESET_ALL} {e}') from None
|
|
2569
|
+
|
|
2570
|
+
# OAuth2/cookie-based authentication flow
|
|
2571
|
+
# TODO(zhwu): this SDK sets global endpoint, which may not be the best
|
|
2572
|
+
# design as a user may expect this is only effective for the current
|
|
2573
|
+
# session. We should consider using env var for specifying endpoint.
|
|
2574
|
+
|
|
2575
|
+
server_status, api_server_info = server_common.check_server_healthy(
|
|
2576
|
+
endpoint)
|
|
2577
|
+
if server_status == server_common.ApiServerStatus.NEEDS_AUTH or relogin:
|
|
2578
|
+
# We detected an auth proxy, so go through the auth proxy cookie flow.
|
|
2579
|
+
token: Optional[str] = None
|
|
2580
|
+
server: Optional[oauth_lib.HTTPServer] = None
|
|
2581
|
+
try:
|
|
2582
|
+
callback_port = common_utils.find_free_port(8000)
|
|
2583
|
+
|
|
2584
|
+
token_container: Dict[str, Optional[str]] = {'token': None}
|
|
2585
|
+
logger.debug('Starting local authentication server...')
|
|
2586
|
+
server = oauth_lib.start_local_auth_server(callback_port,
|
|
2587
|
+
token_container,
|
|
2588
|
+
endpoint)
|
|
2589
|
+
|
|
2590
|
+
token_url = (f'{endpoint}/token?local_port={callback_port}')
|
|
2591
|
+
if webbrowser.open(token_url):
|
|
2592
|
+
click.echo(f'{colorama.Fore.GREEN}A web browser has been '
|
|
2593
|
+
f'opened at {token_url}. Please continue the login '
|
|
2594
|
+
f'in the web browser.{colorama.Style.RESET_ALL}\n'
|
|
2595
|
+
f'{colorama.Style.DIM}To manually copy the token, '
|
|
2596
|
+
f'press ctrl+c.{colorama.Style.RESET_ALL}')
|
|
2597
|
+
else:
|
|
2598
|
+
raise ValueError('Failed to open browser.')
|
|
2599
|
+
|
|
2600
|
+
start_time = time.time()
|
|
2601
|
+
|
|
2602
|
+
while (token_container['token'] is None and
|
|
2603
|
+
time.time() - start_time < oauth_lib.AUTH_TIMEOUT):
|
|
2604
|
+
time.sleep(1)
|
|
2605
|
+
|
|
2606
|
+
if token_container['token'] is None:
|
|
2607
|
+
click.echo(f'{colorama.Fore.YELLOW}Authentication timed out '
|
|
2608
|
+
f'after {oauth_lib.AUTH_TIMEOUT} seconds.')
|
|
2609
|
+
else:
|
|
2610
|
+
token = token_container['token']
|
|
2611
|
+
|
|
2612
|
+
except (Exception, KeyboardInterrupt) as e: # pylint: disable=broad-except
|
|
2613
|
+
logger.debug(f'Automatic authentication failed: {e}, '
|
|
2614
|
+
'falling back to manual token entry.')
|
|
2615
|
+
if isinstance(e, KeyboardInterrupt):
|
|
2616
|
+
click.echo(f'\n{colorama.Style.DIM}Interrupted. Press ctrl+c '
|
|
2617
|
+
f'again to exit.{colorama.Style.RESET_ALL}')
|
|
2618
|
+
# Fall back to manual token entry
|
|
2619
|
+
token_url = f'{endpoint}/token'
|
|
2620
|
+
click.echo('Authentication is needed. Please visit this URL '
|
|
2621
|
+
f'to set up the token:{colorama.Style.BRIGHT}\n\n'
|
|
2622
|
+
f'{token_url}\n{colorama.Style.RESET_ALL}')
|
|
2623
|
+
token = click.prompt('Paste the token')
|
|
2624
|
+
finally:
|
|
2625
|
+
if server is not None:
|
|
2626
|
+
try:
|
|
2627
|
+
server.server_close()
|
|
2628
|
+
except Exception: # pylint: disable=broad-except
|
|
2629
|
+
pass
|
|
2630
|
+
if not token:
|
|
2631
|
+
with ux_utils.print_exception_no_traceback():
|
|
2632
|
+
raise ValueError('Authentication failed.')
|
|
2633
|
+
|
|
2634
|
+
# Parse the token.
|
|
2635
|
+
# b64decode will ignore invalid characters, but does some length and
|
|
2636
|
+
# padding checks.
|
|
2637
|
+
try:
|
|
2638
|
+
data = base64.b64decode(token)
|
|
2639
|
+
except binascii.Error as e:
|
|
2640
|
+
raise ValueError(f'Malformed token: {token}') from e
|
|
2641
|
+
logger.debug(f'Token data: {data!r}')
|
|
2642
|
+
try:
|
|
2643
|
+
json_data = json.loads(data)
|
|
2644
|
+
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
|
2645
|
+
raise ValueError(f'Malformed token data: {data!r}') from e
|
|
2646
|
+
if not isinstance(json_data, dict):
|
|
2647
|
+
raise ValueError(f'Malformed token JSON: {json_data}')
|
|
2648
|
+
|
|
2649
|
+
if json_data.get('v') == 1:
|
|
2650
|
+
user_hash = json_data.get('user')
|
|
2651
|
+
cookie_dict = json_data['cookies']
|
|
2652
|
+
elif 'v' not in json_data:
|
|
2653
|
+
user_hash = None
|
|
2654
|
+
cookie_dict = json_data
|
|
2655
|
+
else:
|
|
2656
|
+
raise ValueError(f'Unsupported token version: {json_data.get("v")}')
|
|
2657
|
+
|
|
2658
|
+
parsed_url = urlparse.urlparse(endpoint)
|
|
2659
|
+
cookie_jar = cookiejar.MozillaCookieJar()
|
|
2660
|
+
for (name, value) in cookie_dict.items():
|
|
2661
|
+
# dict keys in JSON must be strings
|
|
2662
|
+
assert isinstance(name, str)
|
|
2663
|
+
if not isinstance(value, str):
|
|
2664
|
+
raise ValueError('Malformed token - bad key/value: '
|
|
2665
|
+
f'{name}: {value}')
|
|
2666
|
+
|
|
2667
|
+
# See CookieJar._cookie_from_cookie_tuple
|
|
2668
|
+
# oauth2proxy default is Max-Age 604800
|
|
2669
|
+
expires = int(time.time()) + 604800
|
|
2670
|
+
domain = str(parsed_url.hostname)
|
|
2671
|
+
domain_initial_dot = domain.startswith('.')
|
|
2672
|
+
secure = parsed_url.scheme == 'https'
|
|
2673
|
+
if not domain_initial_dot:
|
|
2674
|
+
domain = '.' + domain
|
|
2675
|
+
|
|
2676
|
+
cookie_jar.set_cookie(
|
|
2677
|
+
cookiejar.Cookie(
|
|
2678
|
+
version=0,
|
|
2679
|
+
name=name,
|
|
2680
|
+
value=value,
|
|
2681
|
+
port=None,
|
|
2682
|
+
port_specified=False,
|
|
2683
|
+
domain=domain,
|
|
2684
|
+
domain_specified=True,
|
|
2685
|
+
domain_initial_dot=domain_initial_dot,
|
|
2686
|
+
path='',
|
|
2687
|
+
path_specified=False,
|
|
2688
|
+
secure=secure,
|
|
2689
|
+
expires=expires,
|
|
2690
|
+
discard=False,
|
|
2691
|
+
comment=None,
|
|
2692
|
+
comment_url=None,
|
|
2693
|
+
rest=dict(),
|
|
2694
|
+
))
|
|
2695
|
+
|
|
2696
|
+
# Now that the cookies are parsed, save them to the cookie jar.
|
|
2697
|
+
server_common.set_api_cookie_jar(cookie_jar)
|
|
2698
|
+
|
|
2699
|
+
# Set the user hash in the local file.
|
|
2700
|
+
# If the server already has a token for this user set it to the local
|
|
2701
|
+
# file, otherwise use the new user hash.
|
|
2702
|
+
if (api_server_info.user is not None and
|
|
2703
|
+
api_server_info.user.get('id') is not None):
|
|
2704
|
+
_set_user_hash(api_server_info.user.get('id'))
|
|
2705
|
+
else:
|
|
2706
|
+
_set_user_hash(user_hash)
|
|
2707
|
+
else:
|
|
2708
|
+
# Check if basic auth is enabled
|
|
2709
|
+
if api_server_info.basic_auth_enabled:
|
|
2710
|
+
if api_server_info.user is None:
|
|
2711
|
+
with ux_utils.print_exception_no_traceback():
|
|
2712
|
+
raise ValueError(
|
|
2713
|
+
'Basic auth is enabled but no valid user is found')
|
|
2714
|
+
|
|
2715
|
+
# Set the user hash in the local file.
|
|
2716
|
+
if api_server_info.user is not None:
|
|
2717
|
+
_set_user_hash(api_server_info.user.get('id'))
|
|
2718
|
+
|
|
2719
|
+
# Set the endpoint in the config file
|
|
2720
|
+
_save_config_updates(endpoint=endpoint)
|
|
2721
|
+
dashboard_url = server_common.get_dashboard_url(endpoint)
|
|
2722
|
+
|
|
2723
|
+
# see https://github.com/python/mypy/issues/5107 on why
|
|
2724
|
+
# typing is disabled on this line
|
|
2725
|
+
server_common.get_api_server_status.cache_clear() # type: ignore
|
|
2726
|
+
# After successful authentication, check server health again to get user
|
|
2727
|
+
# identity
|
|
2728
|
+
server_status, final_api_server_info = server_common.check_server_healthy(
|
|
2729
|
+
endpoint)
|
|
2730
|
+
_show_logged_in_message(endpoint, dashboard_url, final_api_server_info.user,
|
|
2731
|
+
server_status)
|
|
2732
|
+
|
|
2733
|
+
|
|
2734
|
+
@usage_lib.entrypoint
@annotations.client_api
def api_logout() -> None:
    """Log out of the SkyPilot API server.

    Replaces the stored API cookie jar with an empty one and removes the
    ``api_server`` settings stored in ~/.sky/config.yaml.

    Raises:
        RuntimeError: If the endpoint is set via the environment variable,
            or if the configured API server is the local one (which should
            be stopped with `sky api stop` instead).
    """
    _check_endpoint_in_env_var(is_login=False)

    if server_common.is_api_server_local():
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError('Local api server cannot be logged out. '
                               'Use `sky api stop` instead.')

    # Overwrite the cookies with an empty jar; nothing needs to be cleared
    # if the cookie file does not exist, hence create_if_not_exists=False.
    empty_jar = cookiejar.MozillaCookieJar()
    server_common.set_api_cookie_jar(empty_jar, create_if_not_exists=False)

    _clear_api_server_config()
    logger.info(f'{colorama.Fore.GREEN}Logged out of SkyPilot API server.'
                f'{colorama.Style.RESET_ALL}')
|