skypilot-nightly 1.0.0.dev20250509__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of skypilot-nightly might be problematic; see the advisory on the package's registry page for more details.
- sky/__init__.py +22 -6
- sky/adaptors/aws.py +25 -7
- sky/adaptors/common.py +24 -1
- sky/adaptors/coreweave.py +278 -0
- sky/adaptors/do.py +8 -2
- sky/adaptors/hyperbolic.py +8 -0
- sky/adaptors/kubernetes.py +149 -18
- sky/adaptors/nebius.py +170 -17
- sky/adaptors/primeintellect.py +1 -0
- sky/adaptors/runpod.py +68 -0
- sky/adaptors/seeweb.py +167 -0
- sky/adaptors/shadeform.py +89 -0
- sky/admin_policy.py +187 -4
- sky/authentication.py +179 -225
- sky/backends/__init__.py +4 -2
- sky/backends/backend.py +22 -9
- sky/backends/backend_utils.py +1299 -380
- sky/backends/cloud_vm_ray_backend.py +1715 -518
- sky/backends/docker_utils.py +1 -1
- sky/backends/local_docker_backend.py +11 -6
- sky/backends/wheel_utils.py +37 -9
- sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
- sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
- sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
- sky/{clouds/service_catalog → catalog}/common.py +89 -48
- sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
- sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +30 -40
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +42 -15
- sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
- sky/catalog/data_fetchers/fetch_nebius.py +335 -0
- sky/catalog/data_fetchers/fetch_runpod.py +698 -0
- sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
- sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
- sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
- sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
- sky/catalog/hyperbolic_catalog.py +136 -0
- sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
- sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
- sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
- sky/catalog/primeintellect_catalog.py +95 -0
- sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
- sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
- sky/catalog/seeweb_catalog.py +184 -0
- sky/catalog/shadeform_catalog.py +165 -0
- sky/catalog/ssh_catalog.py +167 -0
- sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
- sky/check.py +491 -203
- sky/cli.py +5 -6005
- sky/client/{cli.py → cli/command.py} +2477 -1885
- sky/client/cli/deprecation_utils.py +99 -0
- sky/client/cli/flags.py +359 -0
- sky/client/cli/table_utils.py +320 -0
- sky/client/common.py +70 -32
- sky/client/oauth.py +82 -0
- sky/client/sdk.py +1203 -297
- sky/client/sdk_async.py +833 -0
- sky/client/service_account_auth.py +47 -0
- sky/cloud_stores.py +73 -0
- sky/clouds/__init__.py +13 -0
- sky/clouds/aws.py +358 -93
- sky/clouds/azure.py +105 -83
- sky/clouds/cloud.py +127 -36
- sky/clouds/cudo.py +68 -50
- sky/clouds/do.py +66 -48
- sky/clouds/fluidstack.py +63 -44
- sky/clouds/gcp.py +339 -110
- sky/clouds/hyperbolic.py +293 -0
- sky/clouds/ibm.py +70 -49
- sky/clouds/kubernetes.py +563 -162
- sky/clouds/lambda_cloud.py +74 -54
- sky/clouds/nebius.py +206 -80
- sky/clouds/oci.py +88 -66
- sky/clouds/paperspace.py +61 -44
- sky/clouds/primeintellect.py +317 -0
- sky/clouds/runpod.py +164 -74
- sky/clouds/scp.py +89 -83
- sky/clouds/seeweb.py +466 -0
- sky/clouds/shadeform.py +400 -0
- sky/clouds/ssh.py +263 -0
- sky/clouds/utils/aws_utils.py +10 -4
- sky/clouds/utils/gcp_utils.py +87 -11
- sky/clouds/utils/oci_utils.py +38 -14
- sky/clouds/utils/scp_utils.py +177 -124
- sky/clouds/vast.py +99 -77
- sky/clouds/vsphere.py +51 -40
- sky/core.py +349 -139
- sky/dag.py +15 -0
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
- sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-74503c8e80fd253b.js +6 -0
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
- sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
- sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
- sky/dashboard/out/_next/static/chunks/3785.ad6adaa2a0fa9768.js +1 -0
- sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
- sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
- sky/dashboard/out/_next/static/chunks/4725.a830b5c9e7867c92.js +1 -0
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
- sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
- sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
- sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
- sky/dashboard/out/_next/static/chunks/6601-06114c982db410b6.js +1 -0
- sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
- sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
- sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
- sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
- sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
- sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
- sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
- sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
- sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a37d2063af475a1c.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters-d44859594e6f8064.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-6edeb7d06032adfc.js +21 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-479dde13399cf270.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-5ab3b907622cf0fe.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-c5a3eeee1c218af1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces-22b23febb3e89ce1.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
- sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -0
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -0
- sky/dashboard/out/infra.html +1 -0
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -0
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -0
- sky/dashboard/out/volumes.html +1 -0
- sky/dashboard/out/workspace/new.html +1 -0
- sky/dashboard/out/workspaces/[name].html +1 -0
- sky/dashboard/out/workspaces.html +1 -0
- sky/data/data_utils.py +137 -1
- sky/data/mounting_utils.py +269 -84
- sky/data/storage.py +1451 -1807
- sky/data/storage_utils.py +43 -57
- sky/exceptions.py +132 -2
- sky/execution.py +206 -63
- sky/global_user_state.py +2374 -586
- sky/jobs/__init__.py +5 -0
- sky/jobs/client/sdk.py +242 -65
- sky/jobs/client/sdk_async.py +143 -0
- sky/jobs/constants.py +9 -8
- sky/jobs/controller.py +839 -277
- sky/jobs/file_content_utils.py +80 -0
- sky/jobs/log_gc.py +201 -0
- sky/jobs/recovery_strategy.py +398 -152
- sky/jobs/scheduler.py +315 -189
- sky/jobs/server/core.py +829 -255
- sky/jobs/server/server.py +156 -115
- sky/jobs/server/utils.py +136 -0
- sky/jobs/state.py +2092 -701
- sky/jobs/utils.py +1242 -160
- sky/logs/__init__.py +21 -0
- sky/logs/agent.py +108 -0
- sky/logs/aws.py +243 -0
- sky/logs/gcp.py +91 -0
- sky/metrics/__init__.py +0 -0
- sky/metrics/utils.py +443 -0
- sky/models.py +78 -1
- sky/optimizer.py +164 -70
- sky/provision/__init__.py +90 -4
- sky/provision/aws/config.py +147 -26
- sky/provision/aws/instance.py +135 -50
- sky/provision/azure/instance.py +10 -5
- sky/provision/common.py +13 -1
- sky/provision/cudo/cudo_machine_type.py +1 -1
- sky/provision/cudo/cudo_utils.py +14 -8
- sky/provision/cudo/cudo_wrapper.py +72 -71
- sky/provision/cudo/instance.py +10 -6
- sky/provision/do/instance.py +10 -6
- sky/provision/do/utils.py +4 -3
- sky/provision/docker_utils.py +114 -23
- sky/provision/fluidstack/instance.py +13 -8
- sky/provision/gcp/__init__.py +1 -0
- sky/provision/gcp/config.py +301 -19
- sky/provision/gcp/constants.py +218 -0
- sky/provision/gcp/instance.py +36 -8
- sky/provision/gcp/instance_utils.py +18 -4
- sky/provision/gcp/volume_utils.py +247 -0
- sky/provision/hyperbolic/__init__.py +12 -0
- sky/provision/hyperbolic/config.py +10 -0
- sky/provision/hyperbolic/instance.py +437 -0
- sky/provision/hyperbolic/utils.py +373 -0
- sky/provision/instance_setup.py +93 -14
- sky/provision/kubernetes/__init__.py +5 -0
- sky/provision/kubernetes/config.py +9 -52
- sky/provision/kubernetes/constants.py +17 -0
- sky/provision/kubernetes/instance.py +789 -247
- sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
- sky/provision/kubernetes/network.py +27 -17
- sky/provision/kubernetes/network_utils.py +40 -43
- sky/provision/kubernetes/utils.py +1192 -531
- sky/provision/kubernetes/volume.py +282 -0
- sky/provision/lambda_cloud/instance.py +22 -16
- sky/provision/nebius/constants.py +50 -0
- sky/provision/nebius/instance.py +19 -6
- sky/provision/nebius/utils.py +196 -91
- sky/provision/oci/instance.py +10 -5
- sky/provision/paperspace/instance.py +10 -7
- sky/provision/paperspace/utils.py +1 -1
- sky/provision/primeintellect/__init__.py +10 -0
- sky/provision/primeintellect/config.py +11 -0
- sky/provision/primeintellect/instance.py +454 -0
- sky/provision/primeintellect/utils.py +398 -0
- sky/provision/provisioner.py +110 -36
- sky/provision/runpod/__init__.py +5 -0
- sky/provision/runpod/instance.py +27 -6
- sky/provision/runpod/utils.py +51 -18
- sky/provision/runpod/volume.py +180 -0
- sky/provision/scp/__init__.py +15 -0
- sky/provision/scp/config.py +93 -0
- sky/provision/scp/instance.py +531 -0
- sky/provision/seeweb/__init__.py +11 -0
- sky/provision/seeweb/config.py +13 -0
- sky/provision/seeweb/instance.py +807 -0
- sky/provision/shadeform/__init__.py +11 -0
- sky/provision/shadeform/config.py +12 -0
- sky/provision/shadeform/instance.py +351 -0
- sky/provision/shadeform/shadeform_utils.py +83 -0
- sky/provision/ssh/__init__.py +18 -0
- sky/provision/vast/instance.py +13 -8
- sky/provision/vast/utils.py +10 -7
- sky/provision/vsphere/common/vim_utils.py +1 -2
- sky/provision/vsphere/instance.py +15 -10
- sky/provision/vsphere/vsphere_utils.py +9 -19
- sky/py.typed +0 -0
- sky/resources.py +844 -118
- sky/schemas/__init__.py +0 -0
- sky/schemas/api/__init__.py +0 -0
- sky/schemas/api/responses.py +225 -0
- sky/schemas/db/README +4 -0
- sky/schemas/db/env.py +90 -0
- sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
- sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
- sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
- sky/schemas/db/global_user_state/004_is_managed.py +34 -0
- sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
- sky/schemas/db/global_user_state/006_provision_log.py +41 -0
- sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
- sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
- sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
- sky/schemas/db/script.py.mako +28 -0
- sky/schemas/db/serve_state/001_initial_schema.py +67 -0
- sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
- sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
- sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
- sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
- sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
- sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
- sky/schemas/generated/__init__.py +0 -0
- sky/schemas/generated/autostopv1_pb2.py +36 -0
- sky/schemas/generated/autostopv1_pb2.pyi +43 -0
- sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
- sky/schemas/generated/jobsv1_pb2.py +86 -0
- sky/schemas/generated/jobsv1_pb2.pyi +254 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +74 -0
- sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
- sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
- sky/schemas/generated/servev1_pb2.py +58 -0
- sky/schemas/generated/servev1_pb2.pyi +115 -0
- sky/schemas/generated/servev1_pb2_grpc.py +322 -0
- sky/serve/autoscalers.py +357 -5
- sky/serve/client/impl.py +310 -0
- sky/serve/client/sdk.py +47 -139
- sky/serve/client/sdk_async.py +130 -0
- sky/serve/constants.py +10 -8
- sky/serve/controller.py +64 -19
- sky/serve/load_balancer.py +106 -60
- sky/serve/load_balancing_policies.py +115 -1
- sky/serve/replica_managers.py +273 -162
- sky/serve/serve_rpc_utils.py +179 -0
- sky/serve/serve_state.py +554 -251
- sky/serve/serve_utils.py +733 -220
- sky/serve/server/core.py +66 -711
- sky/serve/server/impl.py +1093 -0
- sky/serve/server/server.py +21 -18
- sky/serve/service.py +133 -48
- sky/serve/service_spec.py +135 -16
- sky/serve/spot_placer.py +3 -0
- sky/server/auth/__init__.py +0 -0
- sky/server/auth/authn.py +50 -0
- sky/server/auth/loopback.py +38 -0
- sky/server/auth/oauth2_proxy.py +200 -0
- sky/server/common.py +475 -181
- sky/server/config.py +81 -23
- sky/server/constants.py +44 -6
- sky/server/daemons.py +229 -0
- sky/server/html/token_page.html +185 -0
- sky/server/metrics.py +160 -0
- sky/server/requests/executor.py +528 -138
- sky/server/requests/payloads.py +351 -17
- sky/server/requests/preconditions.py +21 -17
- sky/server/requests/process.py +112 -29
- sky/server/requests/request_names.py +120 -0
- sky/server/requests/requests.py +817 -224
- sky/server/requests/serializers/decoders.py +82 -31
- sky/server/requests/serializers/encoders.py +140 -22
- sky/server/requests/threads.py +106 -0
- sky/server/rest.py +417 -0
- sky/server/server.py +1290 -284
- sky/server/state.py +20 -0
- sky/server/stream_utils.py +345 -57
- sky/server/uvicorn.py +217 -3
- sky/server/versions.py +270 -0
- sky/setup_files/MANIFEST.in +5 -0
- sky/setup_files/alembic.ini +156 -0
- sky/setup_files/dependencies.py +136 -31
- sky/setup_files/setup.py +44 -42
- sky/sky_logging.py +102 -5
- sky/skylet/attempt_skylet.py +1 -0
- sky/skylet/autostop_lib.py +129 -8
- sky/skylet/configs.py +27 -20
- sky/skylet/constants.py +171 -19
- sky/skylet/events.py +105 -21
- sky/skylet/job_lib.py +335 -104
- sky/skylet/log_lib.py +297 -18
- sky/skylet/log_lib.pyi +44 -1
- sky/skylet/ray_patches/__init__.py +17 -3
- sky/skylet/ray_patches/autoscaler.py.diff +18 -0
- sky/skylet/ray_patches/cli.py.diff +19 -0
- sky/skylet/ray_patches/command_runner.py.diff +17 -0
- sky/skylet/ray_patches/log_monitor.py.diff +20 -0
- sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
- sky/skylet/ray_patches/updater.py.diff +18 -0
- sky/skylet/ray_patches/worker.py.diff +41 -0
- sky/skylet/services.py +564 -0
- sky/skylet/skylet.py +63 -4
- sky/skylet/subprocess_daemon.py +103 -29
- sky/skypilot_config.py +506 -99
- sky/ssh_node_pools/__init__.py +1 -0
- sky/ssh_node_pools/core.py +135 -0
- sky/ssh_node_pools/server.py +233 -0
- sky/task.py +621 -137
- sky/templates/aws-ray.yml.j2 +10 -3
- sky/templates/azure-ray.yml.j2 +1 -1
- sky/templates/do-ray.yml.j2 +1 -1
- sky/templates/gcp-ray.yml.j2 +57 -0
- sky/templates/hyperbolic-ray.yml.j2 +67 -0
- sky/templates/jobs-controller.yaml.j2 +27 -24
- sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
- sky/templates/kubernetes-ray.yml.j2 +607 -51
- sky/templates/lambda-ray.yml.j2 +1 -1
- sky/templates/nebius-ray.yml.j2 +33 -12
- sky/templates/paperspace-ray.yml.j2 +1 -1
- sky/templates/primeintellect-ray.yml.j2 +71 -0
- sky/templates/runpod-ray.yml.j2 +9 -1
- sky/templates/scp-ray.yml.j2 +3 -50
- sky/templates/seeweb-ray.yml.j2 +108 -0
- sky/templates/shadeform-ray.yml.j2 +72 -0
- sky/templates/sky-serve-controller.yaml.j2 +22 -2
- sky/templates/websocket_proxy.py +178 -18
- sky/usage/usage_lib.py +18 -11
- sky/users/__init__.py +0 -0
- sky/users/model.conf +15 -0
- sky/users/permission.py +387 -0
- sky/users/rbac.py +121 -0
- sky/users/server.py +720 -0
- sky/users/token_service.py +218 -0
- sky/utils/accelerator_registry.py +34 -5
- sky/utils/admin_policy_utils.py +84 -38
- sky/utils/annotations.py +16 -5
- sky/utils/asyncio_utils.py +78 -0
- sky/utils/auth_utils.py +153 -0
- sky/utils/benchmark_utils.py +60 -0
- sky/utils/cli_utils/status_utils.py +159 -86
- sky/utils/cluster_utils.py +31 -9
- sky/utils/command_runner.py +354 -68
- sky/utils/command_runner.pyi +93 -3
- sky/utils/common.py +35 -8
- sky/utils/common_utils.py +310 -87
- sky/utils/config_utils.py +87 -5
- sky/utils/context.py +402 -0
- sky/utils/context_utils.py +222 -0
- sky/utils/controller_utils.py +264 -89
- sky/utils/dag_utils.py +31 -12
- sky/utils/db/__init__.py +0 -0
- sky/utils/db/db_utils.py +470 -0
- sky/utils/db/migration_utils.py +133 -0
- sky/utils/directory_utils.py +12 -0
- sky/utils/env_options.py +13 -0
- sky/utils/git.py +567 -0
- sky/utils/git_clone.sh +460 -0
- sky/utils/infra_utils.py +195 -0
- sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
- sky/utils/kubernetes/config_map_utils.py +133 -0
- sky/utils/kubernetes/create_cluster.sh +13 -27
- sky/utils/kubernetes/delete_cluster.sh +10 -7
- sky/utils/kubernetes/deploy_remote_cluster.py +1299 -0
- sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
- sky/utils/kubernetes/generate_kind_config.py +6 -66
- sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
- sky/utils/kubernetes/gpu_labeler.py +5 -5
- sky/utils/kubernetes/kubernetes_deploy_utils.py +354 -47
- sky/utils/kubernetes/ssh-tunnel.sh +379 -0
- sky/utils/kubernetes/ssh_utils.py +221 -0
- sky/utils/kubernetes_enums.py +8 -15
- sky/utils/lock_events.py +94 -0
- sky/utils/locks.py +368 -0
- sky/utils/log_utils.py +300 -6
- sky/utils/perf_utils.py +22 -0
- sky/utils/resource_checker.py +298 -0
- sky/utils/resources_utils.py +249 -32
- sky/utils/rich_utils.py +213 -37
- sky/utils/schemas.py +905 -147
- sky/utils/serialize_utils.py +16 -0
- sky/utils/status_lib.py +10 -0
- sky/utils/subprocess_utils.py +38 -15
- sky/utils/tempstore.py +70 -0
- sky/utils/timeline.py +24 -52
- sky/utils/ux_utils.py +84 -15
- sky/utils/validator.py +11 -1
- sky/utils/volume.py +86 -0
- sky/utils/yaml_utils.py +111 -0
- sky/volumes/__init__.py +13 -0
- sky/volumes/client/__init__.py +0 -0
- sky/volumes/client/sdk.py +149 -0
- sky/volumes/server/__init__.py +0 -0
- sky/volumes/server/core.py +258 -0
- sky/volumes/server/server.py +122 -0
- sky/volumes/volume.py +212 -0
- sky/workspaces/__init__.py +0 -0
- sky/workspaces/core.py +655 -0
- sky/workspaces/server.py +101 -0
- sky/workspaces/utils.py +56 -0
- skypilot_nightly-1.0.0.dev20251107.dist-info/METADATA +675 -0
- skypilot_nightly-1.0.0.dev20251107.dist-info/RECORD +594 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +1 -1
- sky/benchmark/benchmark_state.py +0 -256
- sky/benchmark/benchmark_utils.py +0 -641
- sky/clouds/service_catalog/constants.py +0 -7
- sky/dashboard/out/_next/static/LksQgChY5izXjokL3LcEu/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/236-f49500b82ad5392d.js +0 -6
- sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
- sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
- sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
- sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
- sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
- sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
- sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
- sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
- sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-e15db85d0ea1fbe1.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-03f279c6741fb48b.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
- sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
- sky/jobs/dashboard/dashboard.py +0 -223
- sky/jobs/dashboard/static/favicon.ico +0 -0
- sky/jobs/dashboard/templates/index.html +0 -831
- sky/jobs/server/dashboard_utils.py +0 -69
- sky/skylet/providers/scp/__init__.py +0 -2
- sky/skylet/providers/scp/config.py +0 -149
- sky/skylet/providers/scp/node_provider.py +0 -578
- sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
- sky/utils/db_utils.py +0 -100
- sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
- sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
- skypilot_nightly-1.0.0.dev20250509.dist-info/METADATA +0 -361
- skypilot_nightly-1.0.0.dev20250509.dist-info/RECORD +0 -396
- /sky/{clouds/service_catalog → catalog}/config.py +0 -0
- /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
- /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
- /sky/dashboard/out/_next/static/{LksQgChY5izXjokL3LcEu → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
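Most of the catalog-related renames above move sky/clouds/service_catalog to sky/catalog. For downstream code that imported the old module path, the change looks roughly like the sketch below (based only on the import change visible in sky/server/server.py further down; not an official migration guide):

# Before (1.0.0.dev20250509):
from sky.clouds import service_catalog
# After (1.0.0.dev20251107), assuming the import path simply follows the file move:
from sky import catalog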
sky/server/server.py
CHANGED
@@ -2,55 +2,88 @@

 import argparse
 import asyncio
+import base64
+from concurrent.futures import ThreadPoolExecutor
 import contextlib
-import dataclasses
 import datetime
-import
+from enum import IntEnum
+import hashlib
+import json
 import multiprocessing
 import os
 import pathlib
+import posixpath
 import re
+import resource
 import shutil
+import struct
 import sys
-
+import threading
+import traceback
+from typing import Dict, List, Literal, Optional, Set, Tuple
 import uuid
 import zipfile

 import aiofiles
+import anyio
 import fastapi
+from fastapi import responses as fastapi_responses
 from fastapi.middleware import cors
 import starlette.middleware.base
+import uvloop

 import sky
+from sky import catalog
 from sky import check as sky_check
 from sky import clouds
 from sky import core
 from sky import exceptions
 from sky import execution
 from sky import global_user_state
+from sky import models
 from sky import sky_logging
-from sky.clouds import service_catalog
 from sky.data import storage_utils
+from sky.jobs import utils as managed_job_utils
 from sky.jobs.server import server as jobs_rest
+from sky.metrics import utils as metrics_utils
+from sky.provision import metadata_utils
 from sky.provision.kubernetes import utils as kubernetes_utils
+from sky.schemas.api import responses
 from sky.serve.server import server as serve_rest
 from sky.server import common
 from sky.server import config as server_config
 from sky.server import constants as server_constants
+from sky.server import daemons
+from sky.server import metrics
+from sky.server import state
 from sky.server import stream_utils
+from sky.server import versions
+from sky.server.auth import authn
+from sky.server.auth import loopback
+from sky.server.auth import oauth2_proxy
 from sky.server.requests import executor
 from sky.server.requests import payloads
 from sky.server.requests import preconditions
+from sky.server.requests import request_names
 from sky.server.requests import requests as requests_lib
 from sky.skylet import constants
+from sky.ssh_node_pools import server as ssh_node_pools_rest
 from sky.usage import usage_lib
+from sky.users import permission
+from sky.users import server as users_rest
 from sky.utils import admin_policy_utils
 from sky.utils import common as common_lib
 from sky.utils import common_utils
+from sky.utils import context
+from sky.utils import context_utils
 from sky.utils import dag_utils
-from sky.utils import
+from sky.utils import perf_utils
 from sky.utils import status_lib
 from sky.utils import subprocess_utils
+from sky.utils import ux_utils
+from sky.utils.db import db_utils
+from sky.volumes.server import server as volumes_rest
+from sky.workspaces import server as workspaces_rest

 # pylint: disable=ungrouped-imports
 if sys.version_info >= (3, 10):
@@ -60,31 +93,8 @@ else:

 P = ParamSpec('P')

+_SERVER_USER_HASH_KEY = 'server_user_hash'

-def _add_timestamp_prefix_for_server_logs() -> None:
-    server_logger = sky_logging.init_logger('sky.server')
-    # Clear existing handlers first to prevent duplicates
-    server_logger.handlers.clear()
-    # Disable propagation to avoid the root logger of SkyPilot being affected
-    server_logger.propagate = False
-    # Add date prefix to the log message printed by loggers under
-    # server.
-    stream_handler = logging.StreamHandler(sys.stdout)
-    if env_options.Options.SHOW_DEBUG_INFO.get():
-        stream_handler.setLevel(logging.DEBUG)
-    else:
-        stream_handler.setLevel(logging.INFO)
-    stream_handler.flush = sys.stdout.flush  # type: ignore
-    stream_handler.setFormatter(sky_logging.FORMATTER)
-    server_logger.addHandler(stream_handler)
-    # Add date prefix to the log message printed by uvicorn.
-    for name in ['uvicorn', 'uvicorn.access']:
-        uvicorn_logger = logging.getLogger(name)
-        uvicorn_logger.handlers.clear()
-        uvicorn_logger.addHandler(stream_handler)
-
-
-_add_timestamp_prefix_for_server_logs()
 logger = sky_logging.init_logger(__name__)

 # TODO(zhwu): Streaming requests, such log tailing after sky launch or sky logs,
@@ -92,11 +102,72 @@ logger = sky_logging.init_logger(__name__)
 # response will block other requests from being processed.


+def _basic_auth_401_response(content: str):
+    """Return a 401 response with basic auth realm."""
+    return fastapi.responses.JSONResponse(
+        status_code=401,
+        headers={'WWW-Authenticate': 'Basic realm="SkyPilot"'},
+        content=content)
+
+
+def _try_set_basic_auth_user(request: fastapi.Request):
+    auth_header = request.headers.get('authorization')
+    if not auth_header or not auth_header.lower().startswith('basic '):
+        return
+
+    # Check username and password
+    encoded = auth_header.split(' ', 1)[1]
+    try:
+        decoded = base64.b64decode(encoded).decode()
+        username, password = decoded.split(':', 1)
+    except Exception:  # pylint: disable=broad-except
+        return
+
+    users = global_user_state.get_user_by_name(username)
+    if not users:
+        return
+
+    for user in users:
+        if not user.name or not user.password:
+            continue
+        username_encoded = username.encode('utf8')
+        db_username_encoded = user.name.encode('utf8')
+        if (username_encoded == db_username_encoded and
+                common.crypt_ctx.verify(password, user.password)):
+            request.state.auth_user = user
+            break
+
+
+class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
+    """Middleware to handle RBAC."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        # TODO(hailong): should have a list of paths
+        # that are not checked for RBAC
+        if (request.url.path.startswith('/dashboard/') or
+                request.url.path.startswith('/api/')):
+            return await call_next(request)
+
+        auth_user = request.state.auth_user
+        if auth_user is None:
+            return await call_next(request)
+
+        permission_service = permission.permission_service
+        # Check the role permission
+        if permission_service.check_endpoint_permission(auth_user.id,
+                                                        request.url.path,
+                                                        request.method):
+            return fastapi.responses.JSONResponse(
+                status_code=403, content={'detail': 'Forbidden'})
+
+        return await call_next(request)
+
+
 class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
     """Middleware to add a request ID to each request."""

     async def dispatch(self, request: fastapi.Request, call_next):
-        request_id =
+        request_id = requests_lib.get_new_request_id()
         request.state.request_id = request_id
         response = await call_next(request)
         # TODO(syang): remove X-Request-ID when v0.10.0 is released.
@@ -105,6 +176,238 @@ class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
         return response


+def _get_auth_user_header(request: fastapi.Request) -> Optional[models.User]:
+    header_name = os.environ.get(constants.ENV_VAR_SERVER_AUTH_USER_HEADER,
+                                 'X-Auth-Request-Email')
+    if header_name not in request.headers:
+        return None
+    user_name = request.headers[header_name]
+    user_hash = hashlib.md5(
+        user_name.encode()).hexdigest()[:common_utils.USER_HASH_LENGTH]
+    return models.User(id=user_hash, name=user_name)
+
+
+class InitializeRequestAuthUserMiddleware(
+        starlette.middleware.base.BaseHTTPMiddleware):
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        # Make sure that request.state.auth_user is set. Otherwise, we may get a
+        # KeyError while trying to read it.
+        request.state.auth_user = None
+        return await call_next(request)
+
+
+class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
+    """Middleware to handle HTTP Basic Auth."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        if managed_job_utils.is_consolidation_mode(
+        ) and loopback.is_loopback_request(request):
+            return await call_next(request)
+
+        if request.url.path.startswith('/api/health'):
+            # Try to set the auth user from basic auth
+            _try_set_basic_auth_user(request)
+            return await call_next(request)
+
+        auth_header = request.headers.get('authorization')
+        if not auth_header:
+            return _basic_auth_401_response('Authentication required')
+
+        # Only handle basic auth
+        if not auth_header.lower().startswith('basic '):
+            return _basic_auth_401_response('Invalid authentication method')
+
+        # Check username and password
+        encoded = auth_header.split(' ', 1)[1]
+        try:
+            decoded = base64.b64decode(encoded).decode()
+            username, password = decoded.split(':', 1)
+        except Exception:  # pylint: disable=broad-except
+            return _basic_auth_401_response('Invalid basic auth')
+
+        users = global_user_state.get_user_by_name(username)
+        if not users:
+            return _basic_auth_401_response('Invalid credentials')
+
+        valid_user = False
+        for user in users:
+            if not user.name or not user.password:
+                continue
+            username_encoded = username.encode('utf8')
+            db_username_encoded = user.name.encode('utf8')
+            if (username_encoded == db_username_encoded and
+                    common.crypt_ctx.verify(password, user.password)):
+                valid_user = True
+                request.state.auth_user = user
+                await authn.override_user_info_in_request_body(request, user)
+                break
+        if not valid_user:
+            return _basic_auth_401_response('Invalid credentials')
+
+        return await call_next(request)
+
+
+class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
+    """Middleware to handle Bearer Token Auth (Service Accounts)."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        """Make sure correct bearer token auth is present.
+
+        1. If the request has the X-Skypilot-Auth-Mode: token header, it must
+           have a valid bearer token.
+        2. For backwards compatibility, if the request has a Bearer token
+           beginning with "sky_" (even if X-Skypilot-Auth-Mode is not present),
+           it must be a valid token.
+        3. If X-Skypilot-Auth-Mode is not set to "token", and there is no Bearer
+           token beginning with "sky_", allow the request to continue.
+
+        In conjunction with an auth proxy, the idea is to make the auth proxy
+        bypass requests with bearer tokens, instead setting the
+        X-Skypilot-Auth-Mode header. The auth proxy should either validate the
+        auth or set the header X-Skypilot-Auth-Mode: token.
+        """
+        has_skypilot_auth_header = (
+            request.headers.get('X-Skypilot-Auth-Mode') == 'token')
+        auth_header = request.headers.get('authorization')
+        has_bearer_token_starting_with_sky = (
+            auth_header and auth_header.lower().startswith('bearer ') and
+            auth_header.split(' ', 1)[1].startswith('sky_'))
+
+        if (not has_skypilot_auth_header and
+                not has_bearer_token_starting_with_sky):
+            # This is case #3 above. We do not need to validate the request.
+            # No Bearer token, continue with normal processing (OAuth2 cookies,
+            # etc.)
+            return await call_next(request)
+        # After this point, all requests must be validated.
+
+        if auth_header is None:
+            return fastapi.responses.JSONResponse(
+                status_code=401, content={'detail': 'Authentication required'})
+
+        # Extract token
+        split_header = auth_header.split(' ', 1)
+        if split_header[0].lower() != 'bearer':
+            return fastapi.responses.JSONResponse(
+                status_code=401,
+                content={'detail': 'Invalid authentication method'})
+        sa_token = split_header[1]
+
+        # Handle SkyPilot service account tokens
+        return await self._handle_service_account_token(request, sa_token,
+                                                         call_next)
+
+    async def _handle_service_account_token(self, request: fastapi.Request,
+                                            sa_token: str, call_next):
+        """Handle SkyPilot service account tokens."""
+        # Check if service account tokens are enabled
+        sa_enabled = os.environ.get(constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
+                                    'false').lower()
+        if sa_enabled != 'true':
+            return fastapi.responses.JSONResponse(
+                status_code=401,
+                content={'detail': 'Service account authentication disabled'})
+
+        try:
+            # Import here to avoid circular imports
+            # pylint: disable=import-outside-toplevel
+            from sky.users.token_service import token_service
+
+            # Verify and decode JWT token
+            payload = token_service.verify_token(sa_token)
+
+            if payload is None:
+                logger.warning('Service account token verification failed')
+                return fastapi.responses.JSONResponse(
+                    status_code=401,
+                    content={
+                        'detail': 'Invalid or expired service account token'
+                    })
+
+            # Extract user information from JWT payload
+            user_id = payload.get('sub')
+            user_name = payload.get('name')
+            token_id = payload.get('token_id')
+
+            if not user_id or not token_id:
+                logger.warning(
+                    'Invalid token payload: missing user_id or token_id')
+                return fastapi.responses.JSONResponse(
+                    status_code=401,
+                    content={'detail': 'Invalid token payload'})
+
+            # Verify user still exists in database
+            user_info = global_user_state.get_user(user_id)
+            if user_info is None:
+                logger.warning(
+                    f'Service account user {user_id} no longer exists')
+                return fastapi.responses.JSONResponse(
+                    status_code=401,
+                    content={'detail': 'Service account user no longer exists'})
+
+            # Update last used timestamp for token tracking
+            try:
+                global_user_state.update_service_account_token_last_used(
+                    token_id)
+            except Exception as e:  # pylint: disable=broad-except
+                logger.debug(f'Failed to update token last used time: {e}')
+
+            # Set the authenticated user
+            auth_user = models.User(id=user_id,
+                                    name=user_name or user_info.name)
+            request.state.auth_user = auth_user
+
+            # Override user info in request body for service account requests
+            await authn.override_user_info_in_request_body(request, auth_user)
+
+            logger.debug(f'Authenticated service account: {user_id}')
+
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error(f'Service account authentication failed: {e}',
+                         exc_info=True)
+            return fastapi.responses.JSONResponse(
+                status_code=401,
+                content={
+                    'detail': f'Service account authentication failed: {str(e)}'
+                })
+
+        return await call_next(request)
+
+
+class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
+    """Middleware to handle auth proxy."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        auth_user = _get_auth_user_header(request)
+
+        if request.state.auth_user is not None:
+            # Previous middleware is trusted more than this middleware. For
+            # instance, a client could set the Authorization and the
+            # X-Auth-Request-Email header. In that case, the auth proxy will be
+            # skipped and we should rely on the Bearer token to authenticate the
+            # user - but that means the user could set X-Auth-Request-Email to
+            # whatever the user wants. We should thus ignore it.
+            if auth_user is not None:
+                logger.debug('Warning: ignoring auth proxy header since the '
+                             'auth user was already set.')
+            return await call_next(request)
+
+        # Add user to database if auth_user is present
+        if auth_user is not None:
+            newly_added = global_user_state.add_or_update_user(auth_user)
+            if newly_added:
+                permission.permission_service.add_user_if_not_exists(
+                    auth_user.id)
+
+        # Store user info in request.state for access by GET endpoints
+        if auth_user is not None:
+            request.state.auth_user = auth_user
+
+            await authn.override_user_info_in_request_body(request, auth_user)
+        return await call_next(request)
+
+
 # Default expiration time for upload ids before cleanup.
 _DEFAULT_UPLOAD_EXPIRATION_TIME = datetime.timedelta(hours=1)
 # Key: (upload_id, user_hash), Value: the time when the upload id needs to be
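For reference, the three authentication paths these middlewares accept can be exercised from any HTTP client. A minimal sketch follows (illustrative only, not taken from the SkyPilot SDK; the server address and credentials are assumptions):

import base64
import requests

server = 'http://127.0.0.1:46580'  # assumed local API server address

# 1. HTTP Basic auth, handled by BasicAuthMiddleware when the server is
#    started with basic auth enabled.
creds = base64.b64encode(b'alice:secret').decode()  # hypothetical user
requests.get(f'{server}/api/health',
             headers={'Authorization': f'Basic {creds}'})

# 2. Service-account bearer token, handled by BearerTokenMiddleware.
#    Tokens issued by the server start with the 'sky_' prefix.
requests.get(f'{server}/api/health',
             headers={'Authorization': 'Bearer sky_...'})

# 3. Behind an auth proxy, the proxy injects the identity header read by
#    AuthProxyMiddleware (default 'X-Auth-Request-Email'); the client itself
#    sends no credentials.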
@@ -134,21 +437,74 @@ async def cleanup_upload_ids():
             upload_ids_to_cleanup.pop((upload_id, user_hash))


+async def loop_lag_monitor(loop: asyncio.AbstractEventLoop,
+                           interval: float = 0.1) -> None:
+    target = loop.time() + interval
+
+    pid = str(os.getpid())
+    lag_threshold = perf_utils.get_loop_lag_threshold()
+
+    def tick():
+        nonlocal target
+        now = loop.time()
+        lag = max(0.0, now - target)
+        if lag_threshold is not None and lag > lag_threshold:
+            logger.warning(f'Event loop lag {lag} seconds exceeds threshold '
+                           f'{lag_threshold} seconds.')
+        metrics_utils.SKY_APISERVER_EVENT_LOOP_LAG_SECONDS.labels(
+            pid=pid).observe(lag)
+        target = now + interval
+        loop.call_at(target, tick)
+
+    loop.call_at(target, tick)
+
+
+async def schedule_on_boot_check_async():
+    try:
+        await executor.schedule_request_async(
+            request_id='skypilot-server-on-boot-check',
+            request_name=request_names.RequestName.CHECK,
+            request_body=payloads.CheckBody(),
+            func=sky_check.check,
+            schedule_type=requests_lib.ScheduleType.SHORT,
+            is_skypilot_system=True,
+        )
+    except exceptions.RequestAlreadyExistsError:
+        # Lifespan will be executed in each uvicorn worker process, we
+        # can safely ignore the error if the task is already scheduled.
+        logger.debug('Request skypilot-server-on-boot-check already exists.')
+
+
 @contextlib.asynccontextmanager
 async def lifespan(app: fastapi.FastAPI):  # pylint: disable=redefined-outer-name
     """FastAPI lifespan context manager."""
     del app  # unused
     # Startup: Run background tasks
-    for event in
-
-
-
-
-
-
-
-
+    for event in daemons.INTERNAL_REQUEST_DAEMONS:
+        if event.should_skip():
+            continue
+        try:
+            await executor.schedule_request_async(
+                request_id=event.id,
+                request_name=event.name,
+                request_body=payloads.RequestBody(),
+                func=event.run_event,
+                schedule_type=requests_lib.ScheduleType.SHORT,
+                is_skypilot_system=True,
+                # Request deamon should be retried if the process pool is
+                # broken.
+                retryable=True,
+            )
+        except exceptions.RequestAlreadyExistsError:
+            # Lifespan will be executed in each uvicorn worker process, we
+            # can safely ignore the error if the task is already scheduled.
+            logger.debug(f'Request {event.id} already exists.')
+    await schedule_on_boot_check_async()
     asyncio.create_task(cleanup_upload_ids())
+    if metrics_utils.METRICS_ENABLED:
+        # Start monitoring the event loop lag in each server worker
+        # event loop (process).
+        asyncio.create_task(loop_lag_monitor(asyncio.get_event_loop()))
     yield
     # Shutdown: Add any cleanup code here if needed

@@ -166,8 +522,99 @@ class InternalDashboardPrefixMiddleware(
         return await call_next(request)


+class CacheControlStaticMiddleware(starlette.middleware.base.BaseHTTPMiddleware
+                                  ):
+    """Middleware to add cache control headers to static files."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        if request.url.path.startswith('/dashboard/_next'):
+            response = await call_next(request)
+            response.headers['Cache-Control'] = 'max-age=3600'
+            return response
+        return await call_next(request)
+
+
+class PathCleanMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
+    """Middleware to check the path of requests."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        if request.url.path.startswith('/dashboard/'):
+            # If the requested path is not relative to the expected directory,
+            # then the user is attempting path traversal, so deny the request.
+            parent = pathlib.Path('/dashboard')
+            request_path = pathlib.Path(posixpath.normpath(request.url.path))
+            if not _is_relative_to(request_path, parent):
+                return fastapi.responses.JSONResponse(
+                    status_code=403, content={'detail': 'Forbidden'})
+        return await call_next(request)
+
+
+class GracefulShutdownMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
+    """Middleware to control requests when server is shutting down."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        if state.get_block_requests():
+            # Allow /api/ paths to continue, which are critical to operate
+            # on-going requests but will not submit new requests.
+            if not request.url.path.startswith('/api/'):
+                # Client will retry on 503 error.
+                return fastapi.responses.JSONResponse(
+                    status_code=503,
+                    content={
+                        'detail': 'Server is shutting down, '
+                                  'please try again later.'
+                    })
+
+        return await call_next(request)
+
+
+class APIVersionMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
+    """Middleware to add API version to the request."""
+
+    async def dispatch(self, request: fastapi.Request, call_next):
+        version_info = versions.check_compatibility_at_server(request.headers)
+        # Bypass version handling for backward compatibility with clients prior
+        # to v0.11.0, the client will check the version in the body of
+        # /api/health response and hint an upgrade.
+        # TODO(aylei): remove this after v0.13.0 is released.
+        if version_info is None:
+            return await call_next(request)
+        if version_info.error is None:
+            versions.set_remote_api_version(version_info.api_version)
+            versions.set_remote_version(version_info.version)
+            response = await call_next(request)
+        else:
+            response = fastapi.responses.JSONResponse(
+                status_code=400,
+                content={
+                    'error': common.ApiServerStatus.VERSION_MISMATCH.value,
+                    'message': version_info.error,
+                })
+        response.headers[server_constants.API_VERSION_HEADER] = str(
+            server_constants.API_VERSION)
+        response.headers[server_constants.VERSION_HEADER] = \
+            versions.get_local_readable_version()
+        return response
+
+
 app = fastapi.FastAPI(prefix='/api/v1', debug=True, lifespan=lifespan)
+# Middleware wraps in the order defined here. E.g., given
+# app.add_middleware(Middleware1)
+# app.add_middleware(Middleware2)
+# app.add_middleware(Middleware3)
+# The effect will be like:
+# Middleware3(Middleware2(Middleware1(request)))
+# If MiddlewareN does something like print(n); call_next(); print(n), you'll get
+# 3; 2; 1; <request>; 1; 2; 3
+# Use environment variable to make the metrics middleware optional.
+if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
+    app.add_middleware(metrics.PrometheusMiddleware)
+app.add_middleware(APIVersionMiddleware)
+app.add_middleware(RBACMiddleware)
 app.add_middleware(InternalDashboardPrefixMiddleware)
+app.add_middleware(GracefulShutdownMiddleware)
+app.add_middleware(PathCleanMiddleware)
+app.add_middleware(CacheControlStaticMiddleware)
 app.add_middleware(
     cors.CORSMiddleware,
     # TODO(zhwu): in production deployment, we should restrict the allowed
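The ordering comment added above can be verified with a tiny standalone app. A minimal sketch (assumes only fastapi/starlette; it is not part of the SkyPilot code):

import fastapi
import starlette.middleware.base


class Tracer(starlette.middleware.base.BaseHTTPMiddleware):
    """Prints its label before and after the wrapped app handles the request."""

    def __init__(self, app, label: str):
        super().__init__(app)
        self.label = label

    async def dispatch(self, request, call_next):
        print(f'enter {self.label}')
        response = await call_next(request)
        print(f'exit {self.label}')
        return response


demo = fastapi.FastAPI()


@demo.get('/')
async def root():
    return {'ok': True}


demo.add_middleware(Tracer, label='1')
demo.add_middleware(Tracer, label='2')
demo.add_middleware(Tracer, label='3')
# A request to '/' prints: enter 3, enter 2, enter 1, exit 1, exit 2, exit 3,
# i.e. the middleware added last is the outermost wrapper.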
@@ -176,20 +623,119 @@ app.add_middleware(
|
|
|
176
623
|
allow_credentials=True,
|
|
177
624
|
allow_methods=['*'],
|
|
178
625
|
allow_headers=['*'],
|
|
179
|
-
# TODO(syang): remove X-Request-ID when v0.10.0 is released.
|
|
626
|
+
# TODO(syang): remove X-Request-ID \when v0.10.0 is released.
|
|
180
627
|
expose_headers=['X-Request-ID', 'X-Skypilot-Request-ID'])
|
|
628
|
+
# The order of all the authentication-related middleware is important.
|
|
629
|
+
# RBACMiddleware must precede all the auth middleware, so it can access
|
|
630
|
+
# request.state.auth_user.
|
|
631
|
+
app.add_middleware(RBACMiddleware)
|
|
632
|
+
# Authentication based on oauth2-proxy.
|
|
633
|
+
app.add_middleware(oauth2_proxy.OAuth2ProxyMiddleware)
|
|
634
|
+
# AuthProxyMiddleware should precede BasicAuthMiddleware and
|
|
635
|
+
# BearerTokenMiddleware, since it should be skipped if either of those set the
|
|
636
|
+
# auth user.
|
|
637
|
+
app.add_middleware(AuthProxyMiddleware)
|
|
638
|
+
enable_basic_auth = os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false')
|
|
639
|
+
if str(enable_basic_auth).lower() == 'true':
|
|
640
|
+
app.add_middleware(BasicAuthMiddleware)
|
|
641
|
+
# Bearer token middleware should always be present to handle service account
|
|
642
|
+
# authentication
|
|
643
|
+
app.add_middleware(BearerTokenMiddleware)
|
|
644
|
+
# InitializeRequestAuthUserMiddleware must be the last added middleware so that
|
|
645
|
+
# request.state.auth_user is always set, but can be overridden by the auth
|
|
646
|
+
# middleware above.
|
|
647
|
+
app.add_middleware(InitializeRequestAuthUserMiddleware)
|
|
181
648
|
app.add_middleware(RequestIDMiddleware)
|
|
 app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
 app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
+app.include_router(users_rest.router, prefix='/users', tags=['users'])
+app.include_router(workspaces_rest.router,
+                   prefix='/workspaces',
+                   tags=['workspaces'])
+app.include_router(volumes_rest.router, prefix='/volumes', tags=['volumes'])
+app.include_router(ssh_node_pools_rest.router,
+                   prefix='/ssh_node_pools',
+                   tags=['ssh_node_pools'])
+# increase the resource limit for the server
+soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+
+# Increase the limit of files we can open to our hard limit. This fixes bugs
+# where we can not aquire file locks or open enough logs and the API server
+# crashes. On Mac, the hard limit is 9,223,372,036,854,775,807.
+# TODO(luca) figure out what to do if we need to open more than 2^63 files.
+try:
+    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+except Exception:  # pylint: disable=broad-except
+    pass  # no issue, we will warn the user later if its too low
+
+
+@app.exception_handler(exceptions.ConcurrentWorkerExhaustedError)
+def handle_concurrent_worker_exhausted_error(
+        request: fastapi.Request, e: exceptions.ConcurrentWorkerExhaustedError):
+    del request  # request is not used
+    # Print detailed error message to server log
+    logger.error('Concurrent worker exhausted: '
+                 f'{common_utils.format_exception(e)}')
+    with ux_utils.enable_traceback():
+        logger.error(f'  Traceback: {traceback.format_exc()}')
+    # Return human readable error message to client
+    return fastapi.responses.JSONResponse(
+        status_code=503,
+        content={
+            'detail':
+                ('The server has exhausted its concurrent worker limit. '
+                 'Please try again or scale the server if the load persists.')
+        })
+
+
+@app.get('/token')
+async def token(request: fastapi.Request,
+                local_port: Optional[int] = None) -> fastapi.responses.Response:
+    del local_port  # local_port is used by the served js, but ignored by server
+    user = _get_auth_user_header(request)
+
+    token_data = {
+        'v': 1,  # Token version number, bump for backwards incompatible.
+        'user': user.id if user is not None else None,
+        'cookies': request.cookies,
+    }
+    # Use base64 encoding to avoid having to escape anything in the HTML.
+    json_bytes = json.dumps(token_data).encode('utf-8')
+    base64_str = base64.b64encode(json_bytes).decode('utf-8')
+
+    html_dir = pathlib.Path(__file__).parent / 'html'
+    token_page_path = html_dir / 'token_page.html'
+    try:
+        with open(token_page_path, 'r', encoding='utf-8') as f:
+            html_content = f.read()
+    except FileNotFoundError as e:
+        raise fastapi.HTTPException(
+            status_code=500, detail='Token page template not found.') from e
+
+    user_info_string = f'Logged in as {user.name}' if user is not None else ''
+    html_content = html_content.replace(
+        'SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER',
+        base64_str).replace('USER_PLACEHOLDER', user_info_string)
+
+    return fastapi.responses.HTMLResponse(
+        content=html_content,
+        headers={
+            'Cache-Control': 'no-cache, no-transform',
+            # X-Accel-Buffering: no is useful for preventing buffering issues
+            # with some reverse proxies.
+            'X-Accel-Buffering': 'no'
+        })
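The new `/token` handler embeds the token as base64-encoded JSON in the served HTML. A standalone sketch of the same encode/decode round trip (the payload values below are made up):

```python
import base64
import json

# Hypothetical token payload mirroring the shape used in the hunk above.
token_data = {'v': 1, 'user': 'user-1234', 'cookies': {'session': 'abc'}}

# Encode: JSON -> UTF-8 bytes -> base64 string that is safe to drop into HTML.
encoded = base64.b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')

# Decode: what a client would do after copying the token from the page.
decoded = json.loads(base64.b64decode(encoded).decode('utf-8'))
assert decoded == token_data
```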
 
 
 @app.post('/check')
 async def check(request: fastapi.Request,
                 check_body: payloads.CheckBody) -> None:
     """Checks enabled clouds."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CHECK,
         request_body=check_body,
         func=sky_check.check,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -197,12 +743,15 @@ async def check(request: fastapi.Request,
 
 
 @app.get('/enabled_clouds')
-async def enabled_clouds(request: fastapi.Request
+async def enabled_clouds(request: fastapi.Request,
+                         workspace: Optional[str] = None,
+                         expand: bool = False) -> None:
     """Gets enabled clouds on the server."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
-        request_body=payloads.
+        request_name=request_names.RequestName.ENABLED_CLOUDS,
+        request_body=payloads.EnabledCloudsBody(workspace=workspace,
+                                                expand=expand),
         func=core.enabled_clouds,
         schedule_type=requests_lib.ScheduleType.SHORT,
     )
@@ -214,9 +763,10 @@ async def realtime_kubernetes_gpu_availability(
     realtime_gpu_availability_body: payloads.RealtimeGpuAvailabilityRequestBody
 ) -> None:
     """Gets real-time Kubernetes GPU availability."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.
+        REALTIME_KUBERNETES_GPU_AVAILABILITY,
         request_body=realtime_gpu_availability_body,
         func=core.realtime_kubernetes_gpu_availability,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -229,9 +779,9 @@ async def kubernetes_node_info(
     kubernetes_node_info_body: payloads.KubernetesNodeInfoRequestBody
 ) -> None:
     """Gets Kubernetes nodes information and hints."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.KUBERNETES_NODE_INFO,
         request_body=kubernetes_node_info_body,
         func=kubernetes_utils.get_kubernetes_node_info,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -241,9 +791,9 @@ async def kubernetes_node_info(
 @app.get('/status_kubernetes')
 async def status_kubernetes(request: fastapi.Request) -> None:
     """Gets Kubernetes status."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.STATUS_KUBERNETES,
         request_body=payloads.RequestBody(),
         func=core.status_kubernetes,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -255,11 +805,11 @@ async def list_accelerators(
     request: fastapi.Request,
     list_accelerator_counts_body: payloads.ListAcceleratorsBody) -> None:
     """Gets list of accelerators from cloud catalog."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.LIST_ACCELERATORS,
         request_body=list_accelerator_counts_body,
-        func=
+        func=catalog.list_accelerators,
         schedule_type=requests_lib.ScheduleType.SHORT,
     )
 
@@ -270,11 +820,11 @@ async def list_accelerator_counts(
     list_accelerator_counts_body: payloads.ListAcceleratorCountsBody
 ) -> None:
     """Gets list of accelerator counts from cloud catalog."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.LIST_ACCELERATOR_COUNTS,
         request_body=list_accelerator_counts_body,
-        func=
+        func=catalog.list_accelerator_counts,
         schedule_type=requests_lib.ScheduleType.SHORT,
     )
 
@@ -292,25 +842,33 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
     # pairs.
     logger.debug(f'Validating tasks: {validate_body.dag}')
 
+    context.initialize()
+    ctx = context.get()
+    assert ctx is not None
+    # TODO(aylei): generalize this to all requests without a db record.
+    ctx.override_envs(validate_body.env_vars)
+
     def validate_dag(dag: dag_utils.dag_lib.Dag):
         # TODO: Admin policy may contain arbitrary code, which may be expensive
         # to run and may block the server thread. However, moving it into the
         # executor adds a ~150ms penalty on the local API server because of
         # added RTTs. For now, we stick to doing the validation inline in the
         # server thread.
-
-
-
-
-
-
+        with admin_policy_utils.apply_and_use_config_in_current_request(
+                dag,
+                request_name=request_names.AdminPolicyRequestName.VALIDATE,
+                request_options=validate_body.get_request_options()) as dag:
+            dag.resolve_and_validate_volumes()
+            # Skip validating workdir and file_mounts, as those need to be
+            # validated after the files are uploaded to the SkyPilot API server
+            # with `upload_mounts_to_api_server`.
+            dag.validate(skip_file_mounts=True, skip_workdir=True)
 
     try:
         dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag)
-        loop = asyncio.get_running_loop()
         # Apply admin policy and validate DAG is blocking, run it in a separate
         # thread executor to avoid blocking the uvicorn event loop.
-        await
+        await context_utils.to_thread(validate_dag, dag)
     except Exception as e:  # pylint: disable=broad-except
         raise fastapi.HTTPException(
             status_code=400, detail=exceptions.serialize_exception(e)) from e
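The hunk above keeps the blocking admin-policy validation out of the event loop by running it via `context_utils.to_thread`. A rough stand-in using only the standard library (`asyncio.to_thread` plays the same role here; the slow `validate` function below is made up):

```python
import asyncio
import time


def validate(dag: str) -> str:
    # Stand-in for a CPU/IO-heavy, blocking validation step.
    time.sleep(0.5)
    return f'validated: {dag}'


async def handler() -> None:
    # Offload the blocking call so the event loop stays responsive.
    result = await asyncio.to_thread(validate, 'my-dag')
    print(result)


asyncio.run(handler())
```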
@@ -320,9 +878,9 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
 async def optimize(optimize_body: payloads.OptimizeBody,
                    request: fastapi.Request) -> None:
     """Optimizes the user's DAG."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.OPTIMIZE,
         request_body=optimize_body,
         ignore_return_value=True,
         func=core.optimize,
@@ -350,16 +908,30 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
         chunk_index: The chunk index, starting from 0.
         total_chunks: The total number of chunks.
     """
+    # Field _body would be set if the request body has been received, fail fast
+    # to surface potential memory issues, i.e. catch the issue in our smoke
+    # test.
+    # pylint: disable=protected-access
+    if hasattr(request, '_body'):
+        raise fastapi.HTTPException(
+            status_code=500,
+            detail='Upload request body should not be received before streaming'
+        )
     # Add the upload id to the cleanup list.
     upload_ids_to_cleanup[(upload_id,
                            user_hash)] = (datetime.datetime.now() +
                                           _DEFAULT_UPLOAD_EXPIRATION_TIME)
+    # For anonymous access, use the user hash from client
+    user_id = user_hash
+    if request.state.auth_user is not None:
+        # Otherwise, the authenticated identity should be used.
+        user_id = request.state.auth_user.id
 
     # TODO(SKY-1271): We need to double check security of uploading zip file.
     client_file_mounts_dir = (
-        common.API_SERVER_CLIENT_DIR.expanduser().resolve() /
+        common.API_SERVER_CLIENT_DIR.expanduser().resolve() / user_id /
         'file_mounts')
-    client_file_mounts_dir.mkdir(parents=True, exist_ok=True)
+    await anyio.Path(client_file_mounts_dir).mkdir(parents=True, exist_ok=True)
 
     # Check upload_id to be a valid SkyPilot run_timestamp appended with 8 hex
     # characters, e.g. 'sky-2025-01-17-09-10-13-933602-35d31c22'.
@@ -382,7 +954,7 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
         zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
     else:
         chunk_dir = client_file_mounts_dir / upload_id
-        chunk_dir.mkdir(parents=True, exist_ok=True)
+        await anyio.Path(chunk_dir).mkdir(parents=True, exist_ok=True)
         zip_file_path = chunk_dir / f'part{chunk_index}.incomplete'
 
     try:
@@ -412,8 +984,9 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
         zip_file_path.rename(zip_file_path.with_suffix(''))
         missing_chunks = get_missing_chunks(total_chunks)
         if missing_chunks:
-            return payloads.UploadZipFileResponse(
-
+            return payloads.UploadZipFileResponse(
+                status=responses.UploadStatus.UPLOADING.value,
+                missing_chunks=missing_chunks)
         zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
         async with aiofiles.open(zip_file_path, 'wb') as zip_file:
             for chunk in range(total_chunks):
@@ -427,10 +1000,11 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
                 await zip_file.write(data)
 
         logger.info(f'Uploaded zip file: {zip_file_path}')
-        unzip_file(zip_file_path, client_file_mounts_dir)
+        await unzip_file(zip_file_path, client_file_mounts_dir)
         if total_chunks > 1:
-            shutil.rmtree
-            return payloads.UploadZipFileResponse(
+            await context_utils.to_thread(shutil.rmtree, chunk_dir)
+        return payloads.UploadZipFileResponse(
+            status=responses.UploadStatus.COMPLETED.value)
 
 
 def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
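The upload path above receives a zip in chunks (`part{chunk_index}`), reassembles them once all parts are present, and reports an UPLOADING vs COMPLETED status. A simplified, self-contained sketch of the reassembly idea (the file layout and helper name are hypothetical, not SkyPilot's API):

```python
import pathlib
import shutil
from typing import List


def reassemble_chunks(chunk_dir: pathlib.Path, total_chunks: int,
                      output_path: pathlib.Path) -> List[int]:
    """Concatenates part0..part{N-1} into output_path; returns missing parts."""
    missing = [
        i for i in range(total_chunks)
        if not (chunk_dir / f'part{i}').exists()
    ]
    if missing:
        # Caller would respond with an "uploading" status and the missing ids.
        return missing
    with output_path.open('wb') as out:
        for i in range(total_chunks):
            with (chunk_dir / f'part{i}').open('rb') as part:
                # Stream each chunk so large uploads never sit fully in memory.
                shutil.copyfileobj(part, out)
    return []
```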
@@ -443,61 +1017,69 @@ def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
         return False
 
 
-def unzip_file(zip_file_path: pathlib.Path,
-
-    """Unzips a zip file."""
-
-
-
-
-
-
-
-
-
+async def unzip_file(zip_file_path: pathlib.Path,
+                     client_file_mounts_dir: pathlib.Path) -> None:
+    """Unzips a zip file without blocking the event loop."""
+
+    def _do_unzip() -> None:
+        try:
+            with zipfile.ZipFile(zip_file_path, 'r') as zipf:
+                for member in zipf.infolist():
+                    # Determine the new path
+                    original_path = os.path.normpath(member.filename)
+                    new_path = client_file_mounts_dir / original_path.lstrip(
+                        '/')
+
+                    if (member.external_attr >> 28) == 0xA:
+                        # Symlink. Read the target path and create a symlink.
+                        new_path.parent.mkdir(parents=True, exist_ok=True)
+                        target = zipf.read(member).decode()
+                        assert not os.path.isabs(target), target
+                        # Since target is a relative path, we need to check that
+                        # it is under `client_file_mounts_dir` for security.
+                        full_target_path = (new_path.parent / target).resolve()
+                        if not _is_relative_to(full_target_path,
+                                               client_file_mounts_dir):
+                            raise ValueError(
+                                f'Symlink target {target} leads to a '
+                                'file not in userspace. Aborted.')
+
+                        if new_path.exists() or new_path.is_symlink():
+                            new_path.unlink(missing_ok=True)
+                        new_path.symlink_to(
+                            target,
+                            target_is_directory=member.filename.endswith('/'))
+                        continue
+
+                    # Handle directories
+                    if member.filename.endswith('/'):
+                        new_path.mkdir(parents=True, exist_ok=True)
+                        continue
+
+                    # Handle files
                     new_path.parent.mkdir(parents=True, exist_ok=True)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            continue
-
-        # Handle files
-        new_path.parent.mkdir(parents=True, exist_ok=True)
-        with zipf.open(member) as member_file, new_path.open('wb') as f:
-            # Use shutil.copyfileobj to copy files in chunks, so it does
-            # not load the entire file into memory.
-            shutil.copyfileobj(member_file, f)
-    except zipfile.BadZipFile as e:
-        logger.error(f'Bad zip file: {zip_file_path}')
-        raise fastapi.HTTPException(
-            status_code=400,
-            detail=f'Invalid zip file: {common_utils.format_exception(e)}')
-    except Exception as e:
-        logger.error(f'Error unzipping file: {zip_file_path}')
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail=(f'Error unzipping file: '
-                    f'{common_utils.format_exception(e)}'))
+                    with zipf.open(member) as member_file, new_path.open(
+                            'wb') as f:
+                        # Use shutil.copyfileobj to copy files in chunks,
+                        # so it does not load the entire file into memory.
+                        shutil.copyfileobj(member_file, f)
+        except zipfile.BadZipFile as e:
+            logger.error(f'Bad zip file: {zip_file_path}')
+            raise fastapi.HTTPException(
+                status_code=400,
+                detail=f'Invalid zip file: {common_utils.format_exception(e)}')
+        except Exception as e:
+            logger.error(f'Error unzipping file: {zip_file_path}')
+            raise fastapi.HTTPException(
+                status_code=500,
+                detail=(f'Error unzipping file: '
+                        f'{common_utils.format_exception(e)}'))
+        finally:
+            # Cleanup the temporary file regardless of
+            # success/failure handling above
+            zip_file_path.unlink(missing_ok=True)
 
-
-    zip_file_path.unlink()
+    await context_utils.to_thread(_do_unzip)
 
 
 @app.post('/launch')
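The symlink handling above refuses any link whose resolved target escapes the extraction root (a zip-slip guard). A minimal sketch of that containment check, independent of SkyPilot's helpers (paths are illustrative):

```python
import pathlib


def is_within(path: pathlib.Path, root: pathlib.Path) -> bool:
    """True if `path` resolves to a location inside `root`."""
    try:
        path.resolve().relative_to(root.resolve())
        return True
    except ValueError:
        return False


root = pathlib.Path('/tmp/extract_root')
assert is_within(root / 'a' / '..' / 'b', root)
assert not is_within(root / '..' / 'etc' / 'passwd', root)
```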
@@ -506,13 +1088,14 @@ async def launch(launch_body: payloads.LaunchBody,
     """Launches a cluster or task."""
     request_id = request.state.request_id
     logger.info(f'Launching request: {request_id}')
-    executor.
+    await executor.schedule_request_async(
         request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_LAUNCH,
         request_body=launch_body,
         func=execution.launch,
         schedule_type=requests_lib.ScheduleType.LONG,
         request_cluster_name=launch_body.cluster_name,
+        retryable=launch_body.retry_until_up,
     )
 
 
@@ -521,9 +1104,9 @@ async def launch(launch_body: payloads.LaunchBody,
 async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
     """Executes a task on an existing cluster."""
     cluster_name = exec_body.cluster_name
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_EXEC,
         request_body=exec_body,
         func=execution.exec,
         precondition=preconditions.ClusterStartCompletePrecondition(
@@ -539,9 +1122,9 @@ async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
 async def stop(request: fastapi.Request,
                stop_body: payloads.StopOrDownBody) -> None:
     """Stops a cluster."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_STOP,
         request_body=stop_body,
         func=core.stop,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -555,9 +1138,13 @@ async def status(
     status_body: payloads.StatusBody = payloads.StatusBody()
 ) -> None:
     """Gets cluster statuses."""
-
+    if state.get_block_requests():
+        raise fastapi.HTTPException(
+            status_code=503,
+            detail='Server is shutting down, please try again later.')
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_STATUS,
         request_body=status_body,
         func=core.status,
         schedule_type=(requests_lib.ScheduleType.LONG if
@@ -570,9 +1157,9 @@ async def status(
 async def endpoints(request: fastapi.Request,
                     endpoint_body: payloads.EndpointsBody) -> None:
     """Gets the endpoint for a given cluster and port number (endpoint)."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_ENDPOINTS,
         request_body=endpoint_body,
         func=core.endpoints,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -584,9 +1171,9 @@ async def endpoints(request: fastapi.Request,
 async def down(request: fastapi.Request,
                down_body: payloads.StopOrDownBody) -> None:
     """Tears down a cluster."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_DOWN,
         request_body=down_body,
         func=core.down,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -598,9 +1185,9 @@ async def down(request: fastapi.Request,
 async def start(request: fastapi.Request,
                 start_body: payloads.StartBody) -> None:
     """Restarts a cluster."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_START,
         request_body=start_body,
         func=core.start,
         schedule_type=requests_lib.ScheduleType.LONG,
@@ -612,9 +1199,9 @@ async def start(request: fastapi.Request,
 async def autostop(request: fastapi.Request,
                    autostop_body: payloads.AutostopBody) -> None:
     """Schedules an autostop/autodown for a cluster."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_AUTOSTOP,
         request_body=autostop_body,
         func=core.autostop,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -626,9 +1213,9 @@ async def autostop(request: fastapi.Request,
 async def queue(request: fastapi.Request,
                 queue_body: payloads.QueueBody) -> None:
     """Gets the job queue of a cluster."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_QUEUE,
         request_body=queue_body,
         func=core.queue,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -640,9 +1227,9 @@ async def queue(request: fastapi.Request,
 async def job_status(request: fastapi.Request,
                      job_status_body: payloads.JobStatusBody) -> None:
     """Gets the status of a job."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_JOB_STATUS,
         request_body=job_status_body,
         func=core.job_status,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -654,9 +1241,9 @@ async def job_status(request: fastapi.Request,
 async def cancel(request: fastapi.Request,
                  cancel_body: payloads.CancelBody) -> None:
     """Cancels jobs on a cluster."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_JOB_CANCEL,
         request_body=cancel_body,
         func=core.cancel,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -673,36 +1260,27 @@ async def logs(
     # TODO(zhwu): This should wait for the request on the cluster, e.g., async
     # launch, to finish, so that a user does not need to manually pull the
     # request status.
-    executor.
+    executor.check_request_thread_executor_available()
+    request_task = await executor.prepare_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_JOB_LOGS,
         request_body=cluster_job_body,
         func=core.tail_logs,
-        # TODO(aylei): We have tail logs scheduled as SHORT request, because it
-        # should be responsive. However, it can be long running if the user's
-        # job keeps running, and we should avoid it taking the SHORT worker.
         schedule_type=requests_lib.ScheduleType.SHORT,
         request_cluster_name=cluster_job_body.cluster_name,
     )
-
-
-
+    task = executor.execute_request_in_coroutine(request_task)
+    background_tasks.add_task(task.cancel)
     # TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
     # the same approach as /stream.
-    return stream_utils.
-        request_id=
+    return stream_utils.stream_response_for_long_request(
+        request_id=request.state.request_id,
         logs_path=request_task.log_path,
         background_tasks=background_tasks,
+        kill_request_on_disconnect=False,
     )
 
 
-@app.get('/users')
-async def users() -> List[Dict[str, Any]]:
-    """Gets all users."""
-    user_list = global_user_state.get_all_users()
-    return [user.to_dict() for user in user_list]
-
-
 @app.post('/download_logs')
 async def download_logs(
     request: fastapi.Request,
@@ -714,9 +1292,9 @@ async def download_logs(
     # We should reuse the original request body, so that the env vars, such as
     # user hash, are kept the same.
     cluster_jobs_body.local_dir = str(logs_dir_on_api_server)
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.CLUSTER_JOB_DOWNLOAD_LOGS,
         request_body=cluster_jobs_body,
         func=core.download_logs,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -725,7 +1303,8 @@ async def download_logs(
 
 
 @app.post('/download')
-async def download(download_body: payloads.DownloadBody
+async def download(download_body: payloads.DownloadBody,
+                   request: fastapi.Request) -> None:
     """Downloads a folder from the cluster to the local machine."""
     folder_paths = [
         pathlib.Path(folder_path) for folder_path in download_body.folder_paths
@@ -750,11 +1329,25 @@ async def download(download_body: payloads.DownloadBody) -> None:
         logs_dir_on_api_server).expanduser().resolve() / zip_filename
 
     try:
-
-
-
-
-
+
+        def _zip_files_and_folders(folder_paths, zip_path):
+            folders = [
+                str(folder_path.expanduser().resolve())
+                for folder_path in folder_paths
+            ]
+            # Check for optional query parameter to control zip entry structure
+            relative = request.query_params.get('relative', 'home')
+            if relative == 'items':
+                # Dashboard-friendly: entries relative to selected folders
+                storage_utils.zip_files_and_folders(folders,
+                                                    zip_path,
+                                                    relative_to_items=True)
+            else:
+                # CLI-friendly (default): entries with full paths for mapping
+                storage_utils.zip_files_and_folders(folders, zip_path)
+
+        await context_utils.to_thread(_zip_files_and_folders, folder_paths,
+                                      zip_path)
 
         # Add home path to the response headers, so that the client can replace
         # the remote path in the zip file to the local path.
@@ -776,13 +1369,84 @@ async def download(download_body: payloads.DownloadBody) -> None:
             detail=f'Error creating zip file: {str(e)}')
 
 
-
-
+# TODO(aylei): run it asynchronously after global_user_state support async op
+@app.post('/provision_logs')
+def provision_logs(provision_logs_body: payloads.ProvisionLogsBody,
+                   follow: bool = True,
+                   tail: int = 0) -> fastapi.responses.StreamingResponse:
+    """Streams the provision.log for the latest launch request of a cluster."""
+    log_path = None
+    cluster_name = provision_logs_body.cluster_name
+    worker = provision_logs_body.worker
+    # stream head node logs
+    if worker is None:
+        # Prefer clusters table first, then cluster_history as fallback.
+        log_path_str = global_user_state.get_cluster_provision_log_path(
+            cluster_name)
+        if not log_path_str:
+            log_path_str = (
+                global_user_state.get_cluster_history_provision_log_path(
+                    cluster_name))
+        if not log_path_str:
+            raise fastapi.HTTPException(
+                status_code=404,
+                detail=('Provision log path is not recorded for this cluster. '
+                        'Please relaunch to generate provisioning logs.'))
+        log_path = pathlib.Path(log_path_str).expanduser().resolve()
+        if not log_path.exists():
+            raise fastapi.HTTPException(
+                status_code=404,
+                detail=f'Provision log path does not exist: {str(log_path)}')
+
+    # stream worker node logs
+    else:
+        handle = global_user_state.get_handle_from_cluster_name(cluster_name)
+        if handle is None:
+            raise fastapi.HTTPException(
+                status_code=404,
+                detail=('Cluster handle is not recorded for this cluster. '
+                        'Please relaunch to generate provisioning logs.'))
+        # instance_ids includes head node
+        instance_ids = handle.instance_ids
+        if instance_ids is None:
+            raise fastapi.HTTPException(
+                status_code=400,
+                detail='Instance IDs are not recorded for this cluster. '
+                'Please relaunch to generate provisioning logs.')
+        if worker > len(instance_ids) - 1:
+            raise fastapi.HTTPException(
+                status_code=400,
+                detail=f'Worker {worker} is out of range. '
+                f'The cluster has {len(instance_ids)} nodes.')
+        log_path = metadata_utils.get_instance_log_dir(
+            handle.get_cluster_name_on_cloud(), instance_ids[worker])
+
+    # Tail semantics: 0 means print all lines. Convert 0 -> None for streamer.
+    effective_tail = None if tail is None or tail <= 0 else tail
+
+    return fastapi.responses.StreamingResponse(
+        content=stream_utils.log_streamer(None,
+                                          log_path,
+                                          tail=effective_tail,
+                                          follow=follow,
+                                          cluster_name=cluster_name),
+        media_type='text/plain',
+        headers={
+            'Cache-Control': 'no-cache, no-transform',
+            'X-Accel-Buffering': 'no',
+            'Transfer-Encoding': 'chunked',
+        },
+    )
+
+
+@app.post('/cost_report')
+async def cost_report(request: fastapi.Request,
+                      cost_report_body: payloads.CostReportBody) -> None:
     """Gets the cost report of a cluster."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
-        request_body=
+        request_name=request_names.RequestName.CLUSTER_COST_REPORT,
+        request_body=cost_report_body,
         func=core.cost_report,
         schedule_type=requests_lib.ScheduleType.SHORT,
     )
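`provision_logs` above treats `tail=0` (or negative) as "stream everything" and any positive value as "only the last N lines". A tiny, standalone illustration of that tail convention, independent of SkyPilot's `log_streamer` (the helper name is made up):

```python
import collections
from typing import Iterator, Optional


def tail_lines(path: str, tail: Optional[int]) -> Iterator[str]:
    """Yields the whole file when tail is None, else only the last `tail` lines."""
    with open(path, 'r', encoding='utf-8') as f:
        if tail is None:
            yield from f
        else:
            yield from collections.deque(f, maxlen=tail)


# Mirroring the endpoint's convention: 0 or negative means "all lines".
tail = 0
effective_tail = None if tail <= 0 else tail
```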
@@ -791,9 +1455,9 @@ async def cost_report(request: fastapi.Request) -> None:
 @app.get('/storage/ls')
 async def storage_ls(request: fastapi.Request) -> None:
     """Gets the storages."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.STORAGE_LS,
         request_body=payloads.RequestBody(),
         func=core.storage_ls,
         schedule_type=requests_lib.ScheduleType.SHORT,
@@ -804,9 +1468,9 @@ async def storage_ls(request: fastapi.Request) -> None:
 async def storage_delete(request: fastapi.Request,
                          storage_body: payloads.StorageBody) -> None:
     """Deletes a storage."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.STORAGE_DELETE,
         request_body=storage_body,
         func=core.storage_delete,
         schedule_type=requests_lib.ScheduleType.LONG,
@@ -817,9 +1481,9 @@ async def storage_delete(request: fastapi.Request,
 async def local_up(request: fastapi.Request,
                    local_up_body: payloads.LocalUpBody) -> None:
     """Launches a Kubernetes cluster on API server."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.LOCAL_UP,
         request_body=local_up_body,
         func=core.local_up,
         schedule_type=requests_lib.ScheduleType.LONG,
@@ -827,37 +1491,65 @@ async def local_up(request: fastapi.Request,
 
 
 @app.post('/local_down')
-async def local_down(request: fastapi.Request
+async def local_down(request: fastapi.Request,
+                     local_down_body: payloads.LocalDownBody) -> None:
     """Tears down the Kubernetes cluster started by local_up."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
-        request_body=
+        request_name=request_names.RequestName.LOCAL_DOWN,
+        request_body=local_down_body,
         func=core.local_down,
         schedule_type=requests_lib.ScheduleType.LONG,
     )
 
 
+async def get_expanded_request_id(request_id: str) -> str:
+    """Gets the expanded request ID for a given request ID prefix."""
+    request_tasks = await requests_lib.get_requests_async_with_prefix(
+        request_id, fields=['request_id'])
+    if request_tasks is None:
+        raise fastapi.HTTPException(status_code=404,
+                                    detail=f'Request {request_id!r} not found')
+    if len(request_tasks) > 1:
+        raise fastapi.HTTPException(status_code=400,
+                                    detail=('Multiple requests found for '
+                                            f'request ID prefix: {request_id}'))
+    return request_tasks[0].request_id
+
+
 # === API server related APIs ===
-@app.get('/api/get')
-async def api_get(request_id: str) ->
+@app.get('/api/get', response_class=fastapi_responses.ORJSONResponse)
+async def api_get(request_id: str) -> payloads.RequestPayload:
     """Gets a request with a given request ID prefix."""
+    # Validate request_id prefix matches a single request.
+    request_id = await get_expanded_request_id(request_id)
+
     while True:
-
-        if
+        req_status = await requests_lib.get_request_status_async(request_id)
+        if req_status is None:
             print(f'No task with request ID {request_id}', flush=True)
             raise fastapi.HTTPException(
                 status_code=404, detail=f'Request {request_id!r} not found')
-        if
-
-
-
-
-
-        return request_task.encode()
+        if (req_status.status == requests_lib.RequestStatus.RUNNING and
+                daemons.is_daemon_request_id(request_id)):
+            # Daemon requests run forever, break without waiting for complete.
+            break
+        if req_status.status > requests_lib.RequestStatus.RUNNING:
+            break
         # yield control to allow other coroutines to run, sleep shortly
         # to avoid storming the DB and CPU in the meantime
         await asyncio.sleep(0.1)
+    request_task = await requests_lib.get_request_async(request_id)
+    # TODO(aylei): refine this, /api/get will not be retried and this is
+    # meaningless to retry. It is the original request that should be retried.
+    if request_task.should_retry:
+        raise fastapi.HTTPException(
+            status_code=503, detail=f'Request {request_id!r} should be retried')
+    request_error = request_task.get_error()
+    if request_error is not None:
+        raise fastapi.HTTPException(status_code=500,
+                                    detail=request_task.encode().model_dump())
+    return request_task.encode()
 
 
 @app.get('/api/stream')
@@ -891,13 +1583,18 @@ async def stream(
     clients, console for CLI/API clients), 'plain' (force plain text),
     'html' (force HTML), or 'console' (force console)
     """
+    # We need to save the user-supplied request ID for the response header.
+    user_supplied_request_id = request_id
     if request_id is not None and log_path is not None:
         raise fastapi.HTTPException(
             status_code=400,
             detail='Only one of request_id and log_path can be provided')
 
+    if request_id is not None:
+        request_id = await get_expanded_request_id(request_id)
+
     if request_id is None and log_path is None:
-        request_id = requests_lib.
+        request_id = await requests_lib.get_latest_request_id_async()
         if request_id is None:
             raise fastapi.HTTPException(status_code=404,
                                         detail='No request found')
@@ -924,19 +1621,40 @@ async def stream(
             'X-Accel-Buffering': 'no'
         })
 
+    polling_interval = stream_utils.DEFAULT_POLL_INTERVAL
     # Original plain text streaming logic
     if request_id is not None:
-        request_task = requests_lib.
+        request_task = await requests_lib.get_request_async(
+            request_id, fields=['request_id', 'schedule_type'])
         if request_task is None:
             print(f'No task with request ID {request_id}')
             raise fastapi.HTTPException(
                 status_code=404, detail=f'Request {request_id!r} not found')
+        # req.log_path is derived from request_id,
+        # so it's ok to just grab the request_id in the above query.
         log_path_to_stream = request_task.log_path
+        if not log_path_to_stream.exists():
+            # The log file might be deleted by the request GC daemon but the
+            # request task is still in the database.
+            raise fastapi.HTTPException(
+                status_code=404,
+                detail=f'Log of request {request_id!r} has been deleted')
+        if request_task.schedule_type == requests_lib.ScheduleType.LONG:
+            polling_interval = stream_utils.LONG_REQUEST_POLL_INTERVAL
+        del request_task
     else:
         assert log_path is not None, (request_id, log_path)
         if log_path == constants.API_SERVER_LOGS:
             resolved_log_path = pathlib.Path(
                 constants.API_SERVER_LOGS).expanduser()
+            if not resolved_log_path.exists():
+                raise fastapi.HTTPException(
+                    status_code=404,
+                    detail='Server log file does not exist. The API server may '
+                    'have been started with `--foreground` - check the '
+                    'stdout of API server process, such as: '
+                    '`kubectl logs -n api-server-namespace '
+                    'api-server-pod-name`')
         else:
             # This should be a log path under ~/sky_logs.
             resolved_logs_directory = pathlib.Path(
@@ -957,18 +1675,26 @@ async def stream(
             detail=f'Log path {log_path!r} does not exist')
 
         log_path_to_stream = resolved_log_path
+
+    headers = {
+        'Cache-Control': 'no-cache, no-transform',
+        'X-Accel-Buffering': 'no',
+        'Transfer-Encoding': 'chunked'
+    }
+    if request_id is not None:
+        headers[server_constants.STREAM_REQUEST_HEADER] = (
+            user_supplied_request_id
+            if user_supplied_request_id else request_id)
+
     return fastapi.responses.StreamingResponse(
         content=stream_utils.log_streamer(request_id,
                                           log_path_to_stream,
                                           plain_logs=format == 'plain',
                                           tail=tail,
-                                          follow=follow
+                                          follow=follow,
+                                          polling_interval=polling_interval),
         media_type='text/plain',
-        headers=
-            'Cache-Control': 'no-cache, no-transform',
-            'X-Accel-Buffering': 'no',
-            'Transfer-Encoding': 'chunked'
-        },
+        headers=headers,
     )
 
 
@@ -976,11 +1702,11 @@ async def stream(
 async def api_cancel(request: fastapi.Request,
                      request_cancel_body: payloads.RequestCancelBody) -> None:
     """Cancels requests."""
-    executor.
+    await executor.schedule_request_async(
         request_id=request.state.request_id,
-        request_name=
+        request_name=request_names.RequestName.API_CANCEL,
         request_body=request_cancel_body,
-        func=requests_lib.
+        func=requests_lib.kill_requests_with_prefix,
         schedule_type=requests_lib.ScheduleType.SHORT,
     )
 
@@ -988,10 +1714,14 @@ async def api_cancel(request: fastapi.Request,
 @app.get('/api/status')
 async def api_status(
     request_ids: Optional[List[str]] = fastapi.Query(
-        None, description='Request
+        None, description='Request ID prefixes to get status for.'),
     all_status: bool = fastapi.Query(
         False, description='Get finished requests as well.'),
-
+    limit: Optional[int] = fastapi.Query(
+        None, description='Number of requests to show.'),
+    fields: Optional[List[str]] = fastapi.Query(
+        None, description='Fields to get. If None, get all fields.'),
+) -> List[payloads.RequestPayload]:
     """Gets the list of requests."""
     if request_ids is None:
         statuses = None
@@ -1000,53 +1730,120 @@ async def api_status(
             requests_lib.RequestStatus.PENDING,
             requests_lib.RequestStatus.RUNNING,
         ]
-
-
-
+        request_tasks = await requests_lib.get_request_tasks_async(
+            req_filter=requests_lib.RequestTaskFilter(
+                status=statuses,
+                limit=limit,
+                fields=fields,
+                sort=True,
+            ))
+        return requests_lib.encode_requests(request_tasks)
     else:
         encoded_request_tasks = []
         for request_id in request_ids:
-
-
+            request_tasks = await requests_lib.get_requests_async_with_prefix(
+                request_id)
+            if request_tasks is None:
                 continue
-
+            for request_task in request_tasks:
+                encoded_request_tasks.append(request_task.readable_encode())
         return encoded_request_tasks
 
 
-@app.get(
-
+@app.get(
+    '/api/health',
+    # response_model_exclude_unset omits unset fields
+    # in the response JSON.
+    response_model_exclude_unset=True)
+async def health(request: fastapi.Request) -> responses.APIHealthResponse:
     """Checks the health of the API server.
 
     Returns:
-
-    - status: str; The status of the API server.
-    - api_version: str; The API version of the API server.
-    - version: str; The version of SkyPilot used for API server.
-    - version_on_disk: str; The version of the SkyPilot installation on
-      disk, which can be used to warn about restarting the API server
-    - commit: str; The commit hash of SkyPilot used for API server.
+        responses.APIHealthResponse: The health response.
     """
-
-
-
-
-
-
-
+    user = request.state.auth_user
+    server_status = common.ApiServerStatus.HEALTHY
+    if getattr(request.state, 'anonymous_user', False):
+        # API server authentication is enabled, but the request is not
+        # authenticated. We still have to serve the request because the
+        # /api/health endpoint has two different usage:
+        # 1. For health check from `api start` and external ochestration
+        #    tools (k8s), which does not require authentication and user info.
+        # 2. Return server info to client and hint client to login if required.
+        # Separating these two usage to different APIs will break backward
+        # compatibility for existing ochestration solutions (e.g. helm chart).
+        # So we serve these two usages in a backward compatible manner below.
+        client_version = versions.get_remote_api_version()
+        # - For Client with API version >= 14, we return 200 response with
+        #   status=NEEDS_AUTH, new client will handle the login process.
+        # - For health check from `sky api start`, the client code always uses
+        #   the same API version with the server, thus there is no compatibility
+        #   issue.
+        server_status = common.ApiServerStatus.NEEDS_AUTH
+        if client_version is None:
+            # - For health check from ochestration tools (e.g. k8s), we also
+            #   return 200 with status=NEEDS_AUTH, which passes HTTP probe
+            #   check.
+            # - There is no harm when an malicious client calls /api/health
+            #   without authentication since no sensitive information is
+            #   returned.
+            return responses.APIHealthResponse(
+                status=common.ApiServerStatus.HEALTHY,)
+        # TODO(aylei): remove this after min_compatible_api_version >= 14.
+        if client_version < 14:
+            # For Client with API version < 14, the NEEDS_AUTH status is not
+            # honored. Return 401 to trigger the login process.
+            raise fastapi.HTTPException(status_code=401,
+                                        detail='Authentication required')
+
+    logger.debug(f'Health endpoint: request.state.auth_user = {user}')
+    return responses.APIHealthResponse(
+        status=server_status,
+        # Kept for backward compatibility, clients before 0.11.0 will read this
+        # field to check compatibility and hint the user to upgrade the CLI.
+        # TODO(aylei): remove this field after 0.13.0
+        api_version=str(server_constants.API_VERSION),
+        version=sky.__version__,
+        version_on_disk=common.get_skypilot_version_on_disk(),
+        commit=sky.__commit__,
+        # Whether basic auth on api server is enabled
+        basic_auth_enabled=os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH,
+                                          'false').lower() == 'true',
+        user=user if user is not None else None,
+        # Whether service account token is enabled
+        service_account_token_enabled=(os.environ.get(
+            constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
+            'false').lower() == 'true'),
+        # Whether basic auth on ingress is enabled
+        ingress_basic_auth_enabled=os.environ.get(
+            constants.SKYPILOT_INGRESS_BASIC_AUTH_ENABLED,
+            'false').lower() == 'true',
+    )
+
+
+class KubernetesSSHMessageType(IntEnum):
+    REGULAR_DATA = 0
+    PINGPONG = 1
+    LATENCY_MEASUREMENT = 2
 
 
 @app.websocket('/kubernetes-pod-ssh-proxy')
 async def kubernetes_pod_ssh_proxy(
-
-
-) -> None:
+        websocket: fastapi.WebSocket,
+        cluster_name: str,
+        client_version: Optional[int] = None) -> None:
     """Proxies SSH to the Kubernetes pod with websocket."""
     await websocket.accept()
-    cluster_name = cluster_name_body.cluster_name
     logger.info(f'WebSocket connection accepted for cluster: {cluster_name}')
 
-
+    timestamps_supported = client_version is not None and client_version > 21
+    logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
+        client_version = {client_version}')
+
+    # Run core.status in another thread to avoid blocking the event loop.
+    with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
+        cluster_records = await context_utils.to_thread_with_executor(
+            thread_pool_executor, core.status, cluster_name, all_users=True)
     cluster_record = cluster_records[0]
     if cluster_record['status'] != status_lib.ClusterStatus.UP:
         raise fastapi.HTTPException(
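The reworked `/api/health` handler serves both probe traffic and client login hints from a single endpoint. An illustrative client-side probe, hedged: the port and the exact serialized enum values are assumptions; only the presence of `status`/`version` fields is taken from the response model above.

```python
import json
import urllib.request

# Hypothetical local API server URL; adjust to your deployment.
with urllib.request.urlopen('http://localhost:46580/api/health') as resp:
    payload = json.load(resp)

# A load balancer probe only needs the 200 response; a CLI client can also
# inspect the status field to decide whether to start a login flow.
print(payload.get('status'), payload.get('version'))
```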
@@ -1085,17 +1882,70 @@ async def kubernetes_pod_ssh_proxy(
         return
 
     logger.info(f'Starting port-forward to local port: {local_port}')
+    conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
+        pid=os.getpid())
+    ssh_failed = False
+    websocket_closed = False
     try:
+        conn_gauge.inc()
         # Connect to the local port
         reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
 
         async def websocket_to_ssh():
             try:
                 async for message in websocket.iter_bytes():
+                    if timestamps_supported:
+                        type_size = struct.calcsize('!B')
+                        message_type = struct.unpack('!B',
+                                                     message[:type_size])[0]
+                        if (message_type ==
+                                KubernetesSSHMessageType.REGULAR_DATA):
+                            # Regular data - strip type byte and forward to SSH
+                            message = message[type_size:]
+                        elif message_type == KubernetesSSHMessageType.PINGPONG:
+                            # PING message - respond with PONG (type 1)
+                            ping_id_size = struct.calcsize('!I')
+                            if len(message) != type_size + ping_id_size:
+                                raise ValueError('Invalid PING message '
+                                                 f'length: {len(message)}')
+                            # Return the same PING message, so that the client
+                            # can measure the latency.
+                            await websocket.send_bytes(message)
+                            continue
+                        elif (message_type ==
+                              KubernetesSSHMessageType.LATENCY_MEASUREMENT):
+                            # Latency measurement from client
+                            latency_size = struct.calcsize('!Q')
+                            if len(message) != type_size + latency_size:
+                                raise ValueError(
+                                    'Invalid latency measurement '
+                                    f'message length: {len(message)}')
+                            avg_latency_ms = struct.unpack(
+                                '!Q',
+                                message[type_size:type_size + latency_size])[0]
+                            latency_seconds = avg_latency_ms / 1000
+                            metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds)  # pylint: disable=line-too-long
+                            continue
+                        else:
+                            # Unknown message type.
+                            raise ValueError(
+                                f'Unknown message type: {message_type}')
                     writer.write(message)
-
+                    try:
+                        await writer.drain()
+                    except Exception as e:  # pylint: disable=broad-except
+                        # Typically we will not reach here, if the ssh to pod
+                        # is disconnected, ssh_to_websocket will exit first.
+                        # But just in case.
+                        logger.error('Failed to write to pod through '
+                                     f'port-forward connection: {e}')
+                        nonlocal ssh_failed
+                        ssh_failed = True
+                        break
             except fastapi.WebSocketDisconnect:
                 pass
+            nonlocal websocket_closed
+            websocket_closed = True
             writer.close()
 
         async def ssh_to_websocket():
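The proxy above frames every websocket message with a one-byte message type (`!B`) followed by a type-specific payload. A standalone sketch of that framing, using the same struct formats (the values below are made up):

```python
import struct

REGULAR_DATA, PINGPONG, LATENCY_MEASUREMENT = 0, 1, 2

# Client side: a PING carries a 4-byte unsigned ping id after the type byte.
ping = struct.pack('!B', PINGPONG) + struct.pack('!I', 42)

# Server side: read the type byte first, then dispatch on it.
type_size = struct.calcsize('!B')
message_type = struct.unpack('!B', ping[:type_size])[0]
assert message_type == PINGPONG
ping_id = struct.unpack('!I', ping[type_size:])[0]  # -> 42

# A latency report carries an 8-byte unsigned millisecond value ('!Q').
report = struct.pack('!B', LATENCY_MEASUREMENT) + struct.pack('!Q', 135)
latency_ms = struct.unpack('!Q', report[type_size:])[0]  # -> 135
```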
@@ -1103,87 +1953,249 @@ async def kubernetes_pod_ssh_proxy(
                 while True:
                     data = await reader.read(1024)
                     if not data:
+                        if not websocket_closed:
+                            logger.warning('SSH connection to pod is '
+                                           'disconnected before websocket '
+                                           'connection is closed')
+                            nonlocal ssh_failed
+                            ssh_failed = True
                         break
+                    if timestamps_supported:
+                        # Prepend message type byte (0 = regular data)
+                        message_type_bytes = struct.pack(
+                            '!B', KubernetesSSHMessageType.REGULAR_DATA.value)
+                        data = message_type_bytes + data
                     await websocket.send_bytes(data)
             except Exception:  # pylint: disable=broad-except
                 pass
-
+            try:
+                await websocket.close()
+            except Exception:  # pylint: disable=broad-except
+                # The websocket might has been closed by the client.
+                pass

         await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
     finally:
-
+        conn_gauge.dec()
+        reason = ''
+        try:
+            logger.info('Terminating kubectl port-forward process')
+            proc.terminate()
+        except ProcessLookupError:
+            stdout = await proc.stdout.read()
+            logger.error('kubectl port-forward was terminated before the '
+                         'ssh websocket connection was closed. Remaining '
+                         f'output: {str(stdout)}')
+            reason = 'KubectlPortForwardExit'
+            metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
+                pid=os.getpid(), reason='KubectlPortForwardExit').inc()
+        else:
+            if ssh_failed:
+                reason = 'SSHToPodDisconnected'
+            else:
+                reason = 'ClientClosed'
+            metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
+                pid=os.getpid(), reason=reason).inc()
+
+
+@app.get('/all_contexts')
+async def all_contexts(request: fastapi.Request) -> None:
+    """Gets all Kubernetes and SSH node pool contexts."""
+
+    await executor.schedule_request_async(
+        request_id=request.state.request_id,
+        request_name=request_names.RequestName.ALL_CONTEXTS,
+        request_body=payloads.RequestBody(),
+        func=core.get_all_contexts,
+        schedule_type=requests_lib.ScheduleType.SHORT,
+    )


 # === Internal APIs ===
 @app.get('/api/completion/cluster_name')
 async def complete_cluster_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_cluster_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_cluster_names_start_with, incomplete)


 @app.get('/api/completion/storage_name')
 async def complete_storage_name(incomplete: str,) -> List[str]:
-    return global_user_state.get_storage_names_start_with(incomplete)
+    return await context_utils.to_thread(
+        global_user_state.get_storage_names_start_with, incomplete)


-
-
-
-
+@app.get('/api/completion/volume_name')
+async def complete_volume_name(incomplete: str,) -> List[str]:
+    return await context_utils.to_thread(
+        global_user_state.get_volume_names_start_with, incomplete)

-    Handles the /dashboard prefix from Next.js configuration.
-    """
-    # Check if the path starts with 'dashboard/' and remove it if it does
-    if full_path.startswith('dashboard/'):
-        full_path = full_path[len('dashboard/'):]

-
+@app.get('/api/completion/api_request')
+async def complete_api_request(incomplete: str,) -> List[str]:
+    return await requests_lib.get_api_request_ids_start_with(incomplete)
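The completion endpoints above now await the global_user_state lookups through context_utils.to_thread, so the blocking database query runs off the event loop instead of stalling it. A rough standalone equivalent of that pattern, using the standard library's asyncio.to_thread in place of SkyPilot's helper and an illustrative stand-in for the lookup:

import asyncio
from typing import List

def get_cluster_names_start_with(prefix: str) -> List[str]:
    # Stand-in for a blocking database lookup (illustrative only).
    return [name for name in ('my-cluster', 'my-gpu-cluster') if name.startswith(prefix)]

async def complete_cluster_name(incomplete: str) -> List[str]:
    # Offload the blocking call to a worker thread so the event loop stays free.
    return await asyncio.to_thread(get_cluster_names_start_with, incomplete)

print(asyncio.run(complete_cluster_name('my-')))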
+
+
+@app.get('/dashboard/{full_path:path}')
+async def serve_dashboard(full_path: str):
+    """Serves the Next.js dashboard application.
+
+    Args:
+        full_path: The path requested by the client.
+            e.g. /clusters, /jobs
+
+    Returns:
+        FileResponse for static files or index.html for client-side routing.
+
+    Raises:
+        HTTPException: If the path is invalid or file not found.
+    """
+    # Try to serve the staticfile directly e.g. /skypilot.svg,
+    # /favicon.ico, and /_next/, etc.
     file_path = os.path.join(server_constants.DASHBOARD_DIR, full_path)
     if os.path.isfile(file_path):
         return fastapi.responses.FileResponse(file_path)

-    #
-    #
-    # client will be redirected to the index.html.
+    # Serve index.html for client-side routing
+    # e.g. /clusters, /jobs
     index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
     try:
         with open(index_path, 'r', encoding='utf-8') as f:
             content = f.read()
+
         return fastapi.responses.HTMLResponse(content=content)
     except Exception as e:
         logger.error(f'Error serving dashboard: {e}')
         raise fastapi.HTTPException(status_code=500, detail=str(e))


+# Redirect the root path to dashboard
+@app.get('/')
+async def root():
+    return fastapi.responses.RedirectResponse(url='/dashboard/')
+
+
+def _init_or_restore_server_user_hash():
+    """Restores the server user hash from the global user state db.
+
+    The API server must have a stable user hash across restarts and potential
+    multiple replicas. Thus we persist the user hash in db and restore it on
+    startup. When upgrading from old version, the user hash will be read from
+    the local file (if any) to keep the user hash consistent.
+    """
+
+    def apply_user_hash(user_hash: str) -> None:
+        # For local API server, the user hash in db and local file should be
+        # same so there is no harm to override here.
+        common_utils.set_user_hash_locally(user_hash)
+        # Refresh the server user hash for current process after restore or
+        # initialize the user hash in db, child processes will get the correct
+        # server id from the local cache file.
+        common_lib.refresh_server_id()
+
+    user_hash = global_user_state.get_system_config(_SERVER_USER_HASH_KEY)
+    if user_hash is not None:
+        apply_user_hash(user_hash)
+        return
+
+    # Initial deployment, generate a user hash and save it to the db.
+    user_hash = common_utils.get_user_hash()
+    global_user_state.set_system_config(_SERVER_USER_HASH_KEY, user_hash)
+    apply_user_hash(user_hash)
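_init_or_restore_server_user_hash above follows a persist-or-generate pattern: read the hash from the system-config store, and only generate and save a new one on first startup so every restart (and replica) sees the same value. A minimal sketch of that pattern, with a plain dict standing in for the config table:

import uuid

_system_config = {}  # illustrative stand-in for the system config table

def get_or_create_server_user_hash() -> str:
    # Reuse the persisted value if present; otherwise generate once and persist.
    user_hash = _system_config.get('server_user_hash')
    if user_hash is None:
        user_hash = uuid.uuid4().hex[:8]
        _system_config['server_user_hash'] = user_hash
    return user_hash

assert get_or_create_server_user_hash() == get_or_create_server_user_hash()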
+
+
 if __name__ == '__main__':
     import uvicorn

     from sky.server import uvicorn as skyuvicorn

-
+    logger.info('Initializing SkyPilot API server')
+    skyuvicorn.add_timestamp_prefix_for_server_logs()

     parser = argparse.ArgumentParser()
     parser.add_argument('--host', default='127.0.0.1')
     parser.add_argument('--port', default=46580, type=int)
     parser.add_argument('--deploy', action='store_true')
+    # Serve metrics on a separate port to isolate it from the application APIs:
+    # metrics port will not be exposed to the public network typically.
+    parser.add_argument('--metrics-port', default=9090, type=int)
     cmd_args = parser.parse_args()
+    if cmd_args.port == cmd_args.metrics_port:
+        logger.error('port and metrics-port cannot be the same, exiting.')
+        raise ValueError('port and metrics-port cannot be the same')
+
+    # Fail fast if the port is not available to avoid corrupt the state
+    # of potential running server instance.
+    # We might reach here because the running server is currently not
+    # responding, thus the healthz check fails and `sky api start` think
+    # we should start a new server instance.
+    if not common_utils.is_port_available(cmd_args.port):
+        logger.error(f'Port {cmd_args.port} is not available, exiting.')
+        raise RuntimeError(f'Port {cmd_args.port} is not available')
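The startup block above fails fast when --metrics-port collides with --port, or when the API port is already bound, rather than risking the state of an instance that is already running. A self-contained sketch of the same checks, with a socket-based probe standing in for common_utils.is_port_available:

import argparse
import socket

def port_available(port: int, host: str = '127.0.0.1') -> bool:
    # Best-effort probe: the port is considered free if we can bind it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
            return True
        except OSError:
            return False

parser = argparse.ArgumentParser()
parser.add_argument('--port', default=46580, type=int)
parser.add_argument('--metrics-port', default=9090, type=int)
args = parser.parse_args([])
if args.port == args.metrics_port:
    raise ValueError('port and metrics-port cannot be the same')
if not port_available(args.port):
    raise RuntimeError(f'Port {args.port} is not available')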
+
+    # Maybe touch the signal file on API server startup. Do it again here even
+    # if we already touched it in the sky/server/common.py::_start_api_server.
+    # This is because the sky/server/common.py::_start_api_server function call
+    # is running outside the skypilot API server process tree. The process tree
+    # starts within that function (see the `subprocess.Popen` call in
+    # sky/server/common.py::_start_api_server). When pg is used, the
+    # _start_api_server function will not load the config file from db, which
+    # will ignore the consolidation mode config. Here, inside the process tree,
+    # we already reload the config as a server (with env var _start_api_server),
+    # so we will respect the consolidation mode config.
+    # Refers to #7717 for more details.
+    managed_job_utils.is_consolidation_mode(on_api_restart=True)
+
     # Show the privacy policy if it is not already shown. We place it here so
     # that it is shown only when the API server is started.
     usage_lib.maybe_show_privacy_policy()

-
+    # Initialize global user state db
+    db_utils.set_max_connections(1)
+    logger.info('Initializing database engine')
+    global_user_state.initialize_and_get_db()
+    logger.info('Database engine initialized')
+    # Initialize request db
+    requests_lib.reset_db_and_logs()
+    # Restore the server user hash
+    logger.info('Initializing server user hash')
+    _init_or_restore_server_user_hash()
+
+    max_db_connections = global_user_state.get_max_db_connections()
+    logger.info(f'Max db connections: {max_db_connections}')
+    config = server_config.compute_server_config(cmd_args.deploy,
+                                                 max_db_connections)
+
     num_workers = config.num_server_workers

-
+    queue_server: Optional[multiprocessing.Process] = None
+    workers: List[executor.RequestWorker] = []
+    # Global background tasks that will be scheduled in a separate event loop.
+    global_tasks: List[asyncio.Task] = []
     try:
-
+        background = uvloop.new_event_loop()
+        if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
+            metrics_server = metrics.build_metrics_server(
+                cmd_args.host, cmd_args.metrics_port)
+            global_tasks.append(background.create_task(metrics_server.serve()))
+        global_tasks.append(
+            background.create_task(requests_lib.requests_gc_daemon()))
+        global_tasks.append(
+            background.create_task(
+                global_user_state.cluster_event_retention_daemon()))
+        threading.Thread(target=background.run_forever, daemon=True).start()
+
+        queue_server, workers = executor.start(config)
+
         logger.info(f'Starting SkyPilot API server, workers={num_workers}')
         # We don't support reload for now, since it may cause leakage of request
         # workers or interrupt running requests.
-
-
-
-
-
+        uvicorn_config = uvicorn.Config('sky.server.server:app',
+                                        host=cmd_args.host,
+                                        port=cmd_args.port,
+                                        workers=num_workers,
+                                        ws_per_message_deflate=False)
+        skyuvicorn.run(uvicorn_config,
+                       max_db_connections=config.num_db_connections_per_worker)
     except Exception as exc:  # pylint: disable=broad-except
         logger.error(f'Failed to start SkyPilot API server: '
                      f'{common_utils.format_exception(exc, use_bracket=True)}')
@@ -1191,17 +2203,11 @@ if __name__ == '__main__':
     finally:
         logger.info('Shutting down SkyPilot API server...')

-
-
-
-
-
-
-
-
-        # Terminate processes in reverse order in case dependency, especially
-        # queue server. Terminate queue server first does not affect the
-        # correctness of cleanup but introduce redundant error messages.
-        subprocess_utils.run_in_parallel(cleanup,
-                                         list(reversed(sub_procs)),
-                                         num_threads=len(sub_procs))
+        for gt in global_tasks:
+            gt.cancel()
+        subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
+                                         workers,
+                                         num_threads=len(workers))
+        if queue_server is not None:
+            queue_server.kill()
+            queue_server.join()
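The reworked shutdown path cancels the background tasks first, then the request workers, and only then kills and joins the queue-server process, similar in spirit to the reverse-dependency-order cleanup the removed comment described. A small sketch of that ordering with a generic asyncio task and a placeholder child process (not SkyPilot's executor objects):

import asyncio
import multiprocessing
import time

def queue_server_loop():
    # Placeholder child process standing in for the queue server.
    while True:
        time.sleep(0.1)

async def serve_and_shutdown(queue_server):
    background = asyncio.create_task(asyncio.sleep(3600))  # placeholder task
    try:
        await asyncio.sleep(0.1)  # ... serve requests ...
    finally:
        background.cancel()   # 1) cancel background tasks first
        # 2) request workers would be cancelled here, then:
        queue_server.kill()   # 3) stop the queue-server process last
        queue_server.join()

if __name__ == '__main__':
    proc = multiprocessing.Process(target=queue_server_loop, daemon=True)
    proc.start()
    asyncio.run(serve_and_shutdown(proc))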