skypilot-nightly 1.0.0.dev20250502__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +22 -6
- sky/adaptors/aws.py +81 -16
- sky/adaptors/common.py +25 -2
- sky/adaptors/coreweave.py +278 -0
- sky/adaptors/do.py +8 -2
- sky/adaptors/gcp.py +11 -0
- sky/adaptors/hyperbolic.py +8 -0
- sky/adaptors/ibm.py +5 -2
- sky/adaptors/kubernetes.py +149 -18
- sky/adaptors/nebius.py +173 -30
- sky/adaptors/primeintellect.py +1 -0
- sky/adaptors/runpod.py +68 -0
- sky/adaptors/seeweb.py +183 -0
- sky/adaptors/shadeform.py +89 -0
- sky/admin_policy.py +187 -4
- sky/authentication.py +179 -225
- sky/backends/__init__.py +4 -2
- sky/backends/backend.py +22 -9
- sky/backends/backend_utils.py +1323 -397
- sky/backends/cloud_vm_ray_backend.py +1749 -1029
- sky/backends/docker_utils.py +1 -1
- sky/backends/local_docker_backend.py +11 -6
- sky/backends/task_codegen.py +633 -0
- sky/backends/wheel_utils.py +55 -9
- sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
- sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
- sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
- sky/{clouds/service_catalog → catalog}/common.py +90 -49
- sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
- sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +116 -80
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +70 -16
- sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
- sky/catalog/data_fetchers/fetch_nebius.py +338 -0
- sky/catalog/data_fetchers/fetch_runpod.py +698 -0
- sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
- sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
- sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
- sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
- sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
- sky/catalog/hyperbolic_catalog.py +136 -0
- sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
- sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
- sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
- sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
- sky/catalog/primeintellect_catalog.py +95 -0
- sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
- sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
- sky/catalog/seeweb_catalog.py +184 -0
- sky/catalog/shadeform_catalog.py +165 -0
- sky/catalog/ssh_catalog.py +167 -0
- sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
- sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
- sky/check.py +533 -185
- sky/cli.py +5 -5975
- sky/client/{cli.py → cli/command.py} +2591 -1956
- sky/client/cli/deprecation_utils.py +99 -0
- sky/client/cli/flags.py +359 -0
- sky/client/cli/table_utils.py +322 -0
- sky/client/cli/utils.py +79 -0
- sky/client/common.py +78 -32
- sky/client/oauth.py +82 -0
- sky/client/sdk.py +1219 -319
- sky/client/sdk_async.py +827 -0
- sky/client/service_account_auth.py +47 -0
- sky/cloud_stores.py +82 -3
- sky/clouds/__init__.py +13 -0
- sky/clouds/aws.py +564 -164
- sky/clouds/azure.py +105 -83
- sky/clouds/cloud.py +140 -40
- sky/clouds/cudo.py +68 -50
- sky/clouds/do.py +66 -48
- sky/clouds/fluidstack.py +63 -44
- sky/clouds/gcp.py +339 -110
- sky/clouds/hyperbolic.py +293 -0
- sky/clouds/ibm.py +70 -49
- sky/clouds/kubernetes.py +570 -162
- sky/clouds/lambda_cloud.py +74 -54
- sky/clouds/nebius.py +210 -81
- sky/clouds/oci.py +88 -66
- sky/clouds/paperspace.py +61 -44
- sky/clouds/primeintellect.py +317 -0
- sky/clouds/runpod.py +164 -74
- sky/clouds/scp.py +89 -86
- sky/clouds/seeweb.py +477 -0
- sky/clouds/shadeform.py +400 -0
- sky/clouds/ssh.py +263 -0
- sky/clouds/utils/aws_utils.py +10 -4
- sky/clouds/utils/gcp_utils.py +87 -11
- sky/clouds/utils/oci_utils.py +38 -14
- sky/clouds/utils/scp_utils.py +231 -167
- sky/clouds/vast.py +99 -77
- sky/clouds/vsphere.py +51 -40
- sky/core.py +375 -173
- sky/dag.py +15 -0
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
- sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
- sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
- sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
- sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
- sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +1 -0
- sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
- sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
- sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
- sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
- sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
- sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
- sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
- sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
- sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
- sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
- sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
- sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
- sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
- sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
- sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-84a40f8c7c627fe4.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -0
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -0
- sky/dashboard/out/infra.html +1 -0
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -0
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -0
- sky/dashboard/out/volumes.html +1 -0
- sky/dashboard/out/workspace/new.html +1 -0
- sky/dashboard/out/workspaces/[name].html +1 -0
- sky/dashboard/out/workspaces.html +1 -0
- sky/data/data_utils.py +137 -1
- sky/data/mounting_utils.py +269 -84
- sky/data/storage.py +1460 -1807
- sky/data/storage_utils.py +43 -57
- sky/exceptions.py +126 -2
- sky/execution.py +216 -63
- sky/global_user_state.py +2390 -586
- sky/jobs/__init__.py +7 -0
- sky/jobs/client/sdk.py +300 -58
- sky/jobs/client/sdk_async.py +161 -0
- sky/jobs/constants.py +15 -8
- sky/jobs/controller.py +848 -275
- sky/jobs/file_content_utils.py +128 -0
- sky/jobs/log_gc.py +193 -0
- sky/jobs/recovery_strategy.py +402 -152
- sky/jobs/scheduler.py +314 -189
- sky/jobs/server/core.py +836 -255
- sky/jobs/server/server.py +156 -115
- sky/jobs/server/utils.py +136 -0
- sky/jobs/state.py +2109 -706
- sky/jobs/utils.py +1306 -215
- sky/logs/__init__.py +21 -0
- sky/logs/agent.py +108 -0
- sky/logs/aws.py +243 -0
- sky/logs/gcp.py +91 -0
- sky/metrics/__init__.py +0 -0
- sky/metrics/utils.py +453 -0
- sky/models.py +78 -1
- sky/optimizer.py +164 -70
- sky/provision/__init__.py +90 -4
- sky/provision/aws/config.py +147 -26
- sky/provision/aws/instance.py +136 -50
- sky/provision/azure/instance.py +11 -6
- sky/provision/common.py +13 -1
- sky/provision/cudo/cudo_machine_type.py +1 -1
- sky/provision/cudo/cudo_utils.py +14 -8
- sky/provision/cudo/cudo_wrapper.py +72 -71
- sky/provision/cudo/instance.py +10 -6
- sky/provision/do/instance.py +10 -6
- sky/provision/do/utils.py +4 -3
- sky/provision/docker_utils.py +140 -33
- sky/provision/fluidstack/instance.py +13 -8
- sky/provision/gcp/__init__.py +1 -0
- sky/provision/gcp/config.py +301 -19
- sky/provision/gcp/constants.py +218 -0
- sky/provision/gcp/instance.py +36 -8
- sky/provision/gcp/instance_utils.py +18 -4
- sky/provision/gcp/volume_utils.py +247 -0
- sky/provision/hyperbolic/__init__.py +12 -0
- sky/provision/hyperbolic/config.py +10 -0
- sky/provision/hyperbolic/instance.py +437 -0
- sky/provision/hyperbolic/utils.py +373 -0
- sky/provision/instance_setup.py +101 -20
- sky/provision/kubernetes/__init__.py +5 -0
- sky/provision/kubernetes/config.py +9 -52
- sky/provision/kubernetes/constants.py +17 -0
- sky/provision/kubernetes/instance.py +919 -280
- sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
- sky/provision/kubernetes/network.py +27 -17
- sky/provision/kubernetes/network_utils.py +44 -43
- sky/provision/kubernetes/utils.py +1221 -534
- sky/provision/kubernetes/volume.py +343 -0
- sky/provision/lambda_cloud/instance.py +22 -16
- sky/provision/nebius/constants.py +50 -0
- sky/provision/nebius/instance.py +19 -6
- sky/provision/nebius/utils.py +237 -137
- sky/provision/oci/instance.py +10 -5
- sky/provision/paperspace/instance.py +10 -7
- sky/provision/paperspace/utils.py +1 -1
- sky/provision/primeintellect/__init__.py +10 -0
- sky/provision/primeintellect/config.py +11 -0
- sky/provision/primeintellect/instance.py +454 -0
- sky/provision/primeintellect/utils.py +398 -0
- sky/provision/provisioner.py +117 -36
- sky/provision/runpod/__init__.py +5 -0
- sky/provision/runpod/instance.py +27 -6
- sky/provision/runpod/utils.py +51 -18
- sky/provision/runpod/volume.py +214 -0
- sky/provision/scp/__init__.py +15 -0
- sky/provision/scp/config.py +93 -0
- sky/provision/scp/instance.py +707 -0
- sky/provision/seeweb/__init__.py +11 -0
- sky/provision/seeweb/config.py +13 -0
- sky/provision/seeweb/instance.py +812 -0
- sky/provision/shadeform/__init__.py +11 -0
- sky/provision/shadeform/config.py +12 -0
- sky/provision/shadeform/instance.py +351 -0
- sky/provision/shadeform/shadeform_utils.py +83 -0
- sky/provision/ssh/__init__.py +18 -0
- sky/provision/vast/instance.py +13 -8
- sky/provision/vast/utils.py +10 -7
- sky/provision/volume.py +164 -0
- sky/provision/vsphere/common/ssl_helper.py +1 -1
- sky/provision/vsphere/common/vapiconnect.py +2 -1
- sky/provision/vsphere/common/vim_utils.py +4 -4
- sky/provision/vsphere/instance.py +15 -10
- sky/provision/vsphere/vsphere_utils.py +17 -20
- sky/py.typed +0 -0
- sky/resources.py +845 -119
- sky/schemas/__init__.py +0 -0
- sky/schemas/api/__init__.py +0 -0
- sky/schemas/api/responses.py +227 -0
- sky/schemas/db/README +4 -0
- sky/schemas/db/env.py +90 -0
- sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
- sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
- sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
- sky/schemas/db/global_user_state/004_is_managed.py +34 -0
- sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
- sky/schemas/db/global_user_state/006_provision_log.py +41 -0
- sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
- sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
- sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
- sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
- sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
- sky/schemas/db/script.py.mako +28 -0
- sky/schemas/db/serve_state/001_initial_schema.py +67 -0
- sky/schemas/db/serve_state/002_yaml_content.py +34 -0
- sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
- sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
- sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
- sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
- sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
- sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
- sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
- sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
- sky/schemas/generated/__init__.py +0 -0
- sky/schemas/generated/autostopv1_pb2.py +36 -0
- sky/schemas/generated/autostopv1_pb2.pyi +43 -0
- sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
- sky/schemas/generated/jobsv1_pb2.py +86 -0
- sky/schemas/generated/jobsv1_pb2.pyi +254 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
- sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
- sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
- sky/schemas/generated/servev1_pb2.py +58 -0
- sky/schemas/generated/servev1_pb2.pyi +115 -0
- sky/schemas/generated/servev1_pb2_grpc.py +322 -0
- sky/serve/autoscalers.py +357 -5
- sky/serve/client/impl.py +310 -0
- sky/serve/client/sdk.py +47 -139
- sky/serve/client/sdk_async.py +130 -0
- sky/serve/constants.py +12 -9
- sky/serve/controller.py +68 -17
- sky/serve/load_balancer.py +106 -60
- sky/serve/load_balancing_policies.py +116 -2
- sky/serve/replica_managers.py +434 -249
- sky/serve/serve_rpc_utils.py +179 -0
- sky/serve/serve_state.py +569 -257
- sky/serve/serve_utils.py +775 -265
- sky/serve/server/core.py +66 -711
- sky/serve/server/impl.py +1093 -0
- sky/serve/server/server.py +21 -18
- sky/serve/service.py +192 -89
- sky/serve/service_spec.py +144 -20
- sky/serve/spot_placer.py +3 -0
- sky/server/auth/__init__.py +0 -0
- sky/server/auth/authn.py +50 -0
- sky/server/auth/loopback.py +38 -0
- sky/server/auth/oauth2_proxy.py +202 -0
- sky/server/common.py +478 -182
- sky/server/config.py +85 -23
- sky/server/constants.py +44 -6
- sky/server/daemons.py +295 -0
- sky/server/html/token_page.html +185 -0
- sky/server/metrics.py +160 -0
- sky/server/middleware_utils.py +166 -0
- sky/server/requests/executor.py +558 -138
- sky/server/requests/payloads.py +364 -24
- sky/server/requests/preconditions.py +21 -17
- sky/server/requests/process.py +112 -29
- sky/server/requests/request_names.py +121 -0
- sky/server/requests/requests.py +822 -226
- sky/server/requests/serializers/decoders.py +82 -31
- sky/server/requests/serializers/encoders.py +140 -22
- sky/server/requests/threads.py +117 -0
- sky/server/rest.py +455 -0
- sky/server/server.py +1309 -285
- sky/server/state.py +20 -0
- sky/server/stream_utils.py +327 -61
- sky/server/uvicorn.py +217 -3
- sky/server/versions.py +270 -0
- sky/setup_files/MANIFEST.in +11 -1
- sky/setup_files/alembic.ini +160 -0
- sky/setup_files/dependencies.py +139 -31
- sky/setup_files/setup.py +44 -42
- sky/sky_logging.py +114 -7
- sky/skylet/attempt_skylet.py +106 -24
- sky/skylet/autostop_lib.py +129 -8
- sky/skylet/configs.py +29 -20
- sky/skylet/constants.py +216 -25
- sky/skylet/events.py +101 -21
- sky/skylet/job_lib.py +345 -164
- sky/skylet/log_lib.py +297 -18
- sky/skylet/log_lib.pyi +44 -1
- sky/skylet/providers/ibm/node_provider.py +12 -8
- sky/skylet/providers/ibm/vpc_provider.py +13 -12
- sky/skylet/ray_patches/__init__.py +17 -3
- sky/skylet/ray_patches/autoscaler.py.diff +18 -0
- sky/skylet/ray_patches/cli.py.diff +19 -0
- sky/skylet/ray_patches/command_runner.py.diff +17 -0
- sky/skylet/ray_patches/log_monitor.py.diff +20 -0
- sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
- sky/skylet/ray_patches/updater.py.diff +18 -0
- sky/skylet/ray_patches/worker.py.diff +41 -0
- sky/skylet/runtime_utils.py +21 -0
- sky/skylet/services.py +568 -0
- sky/skylet/skylet.py +72 -4
- sky/skylet/subprocess_daemon.py +104 -29
- sky/skypilot_config.py +506 -99
- sky/ssh_node_pools/__init__.py +1 -0
- sky/ssh_node_pools/core.py +135 -0
- sky/ssh_node_pools/server.py +233 -0
- sky/task.py +685 -163
- sky/templates/aws-ray.yml.j2 +11 -3
- sky/templates/azure-ray.yml.j2 +2 -1
- sky/templates/cudo-ray.yml.j2 +1 -0
- sky/templates/do-ray.yml.j2 +2 -1
- sky/templates/fluidstack-ray.yml.j2 +1 -0
- sky/templates/gcp-ray.yml.j2 +62 -1
- sky/templates/hyperbolic-ray.yml.j2 +68 -0
- sky/templates/ibm-ray.yml.j2 +2 -1
- sky/templates/jobs-controller.yaml.j2 +27 -24
- sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
- sky/templates/kubernetes-ray.yml.j2 +611 -50
- sky/templates/lambda-ray.yml.j2 +2 -1
- sky/templates/nebius-ray.yml.j2 +34 -12
- sky/templates/oci-ray.yml.j2 +1 -0
- sky/templates/paperspace-ray.yml.j2 +2 -1
- sky/templates/primeintellect-ray.yml.j2 +72 -0
- sky/templates/runpod-ray.yml.j2 +10 -1
- sky/templates/scp-ray.yml.j2 +4 -50
- sky/templates/seeweb-ray.yml.j2 +171 -0
- sky/templates/shadeform-ray.yml.j2 +73 -0
- sky/templates/sky-serve-controller.yaml.j2 +22 -2
- sky/templates/vast-ray.yml.j2 +1 -0
- sky/templates/vsphere-ray.yml.j2 +1 -0
- sky/templates/websocket_proxy.py +212 -37
- sky/usage/usage_lib.py +31 -15
- sky/users/__init__.py +0 -0
- sky/users/model.conf +15 -0
- sky/users/permission.py +397 -0
- sky/users/rbac.py +121 -0
- sky/users/server.py +720 -0
- sky/users/token_service.py +218 -0
- sky/utils/accelerator_registry.py +35 -5
- sky/utils/admin_policy_utils.py +84 -38
- sky/utils/annotations.py +38 -5
- sky/utils/asyncio_utils.py +78 -0
- sky/utils/atomic.py +1 -1
- sky/utils/auth_utils.py +153 -0
- sky/utils/benchmark_utils.py +60 -0
- sky/utils/cli_utils/status_utils.py +159 -86
- sky/utils/cluster_utils.py +31 -9
- sky/utils/command_runner.py +354 -68
- sky/utils/command_runner.pyi +93 -3
- sky/utils/common.py +35 -8
- sky/utils/common_utils.py +314 -91
- sky/utils/config_utils.py +74 -5
- sky/utils/context.py +403 -0
- sky/utils/context_utils.py +242 -0
- sky/utils/controller_utils.py +383 -89
- sky/utils/dag_utils.py +31 -12
- sky/utils/db/__init__.py +0 -0
- sky/utils/db/db_utils.py +485 -0
- sky/utils/db/kv_cache.py +149 -0
- sky/utils/db/migration_utils.py +137 -0
- sky/utils/directory_utils.py +12 -0
- sky/utils/env_options.py +13 -0
- sky/utils/git.py +567 -0
- sky/utils/git_clone.sh +460 -0
- sky/utils/infra_utils.py +195 -0
- sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
- sky/utils/kubernetes/config_map_utils.py +133 -0
- sky/utils/kubernetes/create_cluster.sh +15 -29
- sky/utils/kubernetes/delete_cluster.sh +10 -7
- sky/utils/kubernetes/deploy_ssh_node_pools.py +1177 -0
- sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
- sky/utils/kubernetes/generate_kind_config.py +6 -66
- sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
- sky/utils/kubernetes/gpu_labeler.py +18 -8
- sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
- sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
- sky/utils/kubernetes/kubernetes_deploy_utils.py +284 -114
- sky/utils/kubernetes/rsync_helper.sh +11 -3
- sky/utils/kubernetes/ssh-tunnel.sh +379 -0
- sky/utils/kubernetes/ssh_utils.py +221 -0
- sky/utils/kubernetes_enums.py +8 -15
- sky/utils/lock_events.py +94 -0
- sky/utils/locks.py +416 -0
- sky/utils/log_utils.py +82 -107
- sky/utils/perf_utils.py +22 -0
- sky/utils/resource_checker.py +298 -0
- sky/utils/resources_utils.py +249 -32
- sky/utils/rich_utils.py +217 -39
- sky/utils/schemas.py +955 -160
- sky/utils/serialize_utils.py +16 -0
- sky/utils/status_lib.py +10 -0
- sky/utils/subprocess_utils.py +29 -15
- sky/utils/tempstore.py +70 -0
- sky/utils/thread_utils.py +91 -0
- sky/utils/timeline.py +26 -53
- sky/utils/ux_utils.py +84 -15
- sky/utils/validator.py +11 -1
- sky/utils/volume.py +165 -0
- sky/utils/yaml_utils.py +111 -0
- sky/volumes/__init__.py +13 -0
- sky/volumes/client/__init__.py +0 -0
- sky/volumes/client/sdk.py +150 -0
- sky/volumes/server/__init__.py +0 -0
- sky/volumes/server/core.py +270 -0
- sky/volumes/server/server.py +124 -0
- sky/volumes/volume.py +215 -0
- sky/workspaces/__init__.py +0 -0
- sky/workspaces/core.py +655 -0
- sky/workspaces/server.py +101 -0
- sky/workspaces/utils.py +56 -0
- sky_templates/README.md +3 -0
- sky_templates/__init__.py +3 -0
- sky_templates/ray/__init__.py +0 -0
- sky_templates/ray/start_cluster +183 -0
- sky_templates/ray/stop_cluster +75 -0
- skypilot_nightly-1.0.0.dev20251203.dist-info/METADATA +676 -0
- skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
- {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +1 -1
- skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
- sky/benchmark/benchmark_state.py +0 -256
- sky/benchmark/benchmark_utils.py +0 -641
- sky/clouds/service_catalog/constants.py +0 -7
- sky/dashboard/out/_next/static/GWvVBSCS7FmUiVmjaL1a7/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/236-2db3ee3fba33dd9e.js +0 -6
- sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
- sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
- sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
- sky/dashboard/out/_next/static/chunks/845-9e60713e0c441abc.js +0 -1
- sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
- sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
- sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
- sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
- sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6ac338bc2239cb45.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1c519e1afc523dc9.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
- sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
- sky/jobs/dashboard/dashboard.py +0 -223
- sky/jobs/dashboard/static/favicon.ico +0 -0
- sky/jobs/dashboard/templates/index.html +0 -831
- sky/jobs/server/dashboard_utils.py +0 -69
- sky/skylet/providers/scp/__init__.py +0 -2
- sky/skylet/providers/scp/config.py +0 -149
- sky/skylet/providers/scp/node_provider.py +0 -578
- sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
- sky/utils/db_utils.py +0 -100
- sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
- sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
- skypilot_nightly-1.0.0.dev20250502.dist-info/METADATA +0 -361
- skypilot_nightly-1.0.0.dev20250502.dist-info/RECORD +0 -396
- skypilot_nightly-1.0.0.dev20250502.dist-info/top_level.txt +0 -1
- /sky/{clouds/service_catalog → catalog}/config.py +0 -0
- /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
- /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
- /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
- /sky/dashboard/out/_next/static/{GWvVBSCS7FmUiVmjaL1a7 → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
sky/server/server.py
CHANGED
|
@@ -2,55 +2,91 @@
|
|
|
2
2
|
|
|
3
3
|
import argparse
|
|
4
4
|
import asyncio
|
|
5
|
+
import base64
|
|
6
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
5
7
|
import contextlib
|
|
6
|
-
import dataclasses
|
|
7
8
|
import datetime
|
|
8
|
-
import
|
|
9
|
+
from enum import IntEnum
|
|
10
|
+
import hashlib
|
|
11
|
+
import json
|
|
9
12
|
import multiprocessing
|
|
10
13
|
import os
|
|
11
14
|
import pathlib
|
|
15
|
+
import posixpath
|
|
12
16
|
import re
|
|
17
|
+
import resource
|
|
13
18
|
import shutil
|
|
19
|
+
import struct
|
|
14
20
|
import sys
|
|
15
|
-
|
|
21
|
+
import threading
|
|
22
|
+
import traceback
|
|
23
|
+
from typing import Dict, List, Literal, Optional, Set, Tuple
|
|
16
24
|
import uuid
|
|
17
25
|
import zipfile
|
|
18
26
|
|
|
19
27
|
import aiofiles
|
|
28
|
+
import anyio
|
|
20
29
|
import fastapi
|
|
30
|
+
from fastapi import responses as fastapi_responses
|
|
21
31
|
from fastapi.middleware import cors
|
|
22
32
|
import starlette.middleware.base
|
|
33
|
+
import uvloop
|
|
23
34
|
|
|
24
35
|
import sky
|
|
36
|
+
from sky import catalog
|
|
25
37
|
from sky import check as sky_check
|
|
26
38
|
from sky import clouds
|
|
27
39
|
from sky import core
|
|
28
40
|
from sky import exceptions
|
|
29
41
|
from sky import execution
|
|
30
42
|
from sky import global_user_state
|
|
43
|
+
from sky import models
|
|
31
44
|
from sky import sky_logging
|
|
32
|
-
from sky.clouds import service_catalog
|
|
33
45
|
from sky.data import storage_utils
|
|
46
|
+
from sky.jobs import utils as managed_job_utils
|
|
34
47
|
from sky.jobs.server import server as jobs_rest
|
|
48
|
+
from sky.metrics import utils as metrics_utils
|
|
49
|
+
from sky.provision import metadata_utils
|
|
35
50
|
from sky.provision.kubernetes import utils as kubernetes_utils
|
|
51
|
+
from sky.schemas.api import responses
|
|
36
52
|
from sky.serve.server import server as serve_rest
|
|
37
53
|
from sky.server import common
|
|
38
54
|
from sky.server import config as server_config
|
|
39
55
|
from sky.server import constants as server_constants
|
|
56
|
+
from sky.server import daemons
|
|
57
|
+
from sky.server import metrics
|
|
58
|
+
from sky.server import middleware_utils
|
|
59
|
+
from sky.server import state
|
|
40
60
|
from sky.server import stream_utils
|
|
61
|
+
from sky.server import versions
|
|
62
|
+
from sky.server.auth import authn
|
|
63
|
+
from sky.server.auth import loopback
|
|
64
|
+
from sky.server.auth import oauth2_proxy
|
|
41
65
|
from sky.server.requests import executor
|
|
42
66
|
from sky.server.requests import payloads
|
|
43
67
|
from sky.server.requests import preconditions
|
|
68
|
+
from sky.server.requests import request_names
|
|
44
69
|
from sky.server.requests import requests as requests_lib
|
|
45
70
|
from sky.skylet import constants
|
|
71
|
+
from sky.ssh_node_pools import server as ssh_node_pools_rest
|
|
46
72
|
from sky.usage import usage_lib
|
|
73
|
+
from sky.users import permission
|
|
74
|
+
from sky.users import server as users_rest
|
|
47
75
|
from sky.utils import admin_policy_utils
|
|
48
76
|
from sky.utils import common as common_lib
|
|
49
77
|
from sky.utils import common_utils
|
|
78
|
+
from sky.utils import context
|
|
79
|
+
from sky.utils import context_utils
|
|
80
|
+
from sky.utils import controller_utils
|
|
50
81
|
from sky.utils import dag_utils
|
|
51
82
|
from sky.utils import env_options
|
|
83
|
+
from sky.utils import perf_utils
|
|
52
84
|
from sky.utils import status_lib
|
|
53
85
|
from sky.utils import subprocess_utils
|
|
86
|
+
from sky.utils import ux_utils
|
|
87
|
+
from sky.utils.db import db_utils
|
|
88
|
+
from sky.volumes.server import server as volumes_rest
|
|
89
|
+
from sky.workspaces import server as workspaces_rest
|
|
54
90
|
|
|
55
91
|
# pylint: disable=ungrouped-imports
|
|
56
92
|
if sys.version_info >= (3, 10):
|
|
@@ -60,31 +96,8 @@ else:
|
|
|
60
96
|
|
|
61
97
|
P = ParamSpec('P')
|
|
62
98
|
|
|
99
|
+
_SERVER_USER_HASH_KEY = 'server_user_hash'
|
|
63
100
|
|
|
64
|
-
def _add_timestamp_prefix_for_server_logs() -> None:
|
|
65
|
-
server_logger = sky_logging.init_logger('sky.server')
|
|
66
|
-
# Clear existing handlers first to prevent duplicates
|
|
67
|
-
server_logger.handlers.clear()
|
|
68
|
-
# Disable propagation to avoid the root logger of SkyPilot being affected
|
|
69
|
-
server_logger.propagate = False
|
|
70
|
-
# Add date prefix to the log message printed by loggers under
|
|
71
|
-
# server.
|
|
72
|
-
stream_handler = logging.StreamHandler(sys.stdout)
|
|
73
|
-
if env_options.Options.SHOW_DEBUG_INFO.get():
|
|
74
|
-
stream_handler.setLevel(logging.DEBUG)
|
|
75
|
-
else:
|
|
76
|
-
stream_handler.setLevel(logging.INFO)
|
|
77
|
-
stream_handler.flush = sys.stdout.flush # type: ignore
|
|
78
|
-
stream_handler.setFormatter(sky_logging.FORMATTER)
|
|
79
|
-
server_logger.addHandler(stream_handler)
|
|
80
|
-
# Add date prefix to the log message printed by uvicorn.
|
|
81
|
-
for name in ['uvicorn', 'uvicorn.access']:
|
|
82
|
-
uvicorn_logger = logging.getLogger(name)
|
|
83
|
-
uvicorn_logger.handlers.clear()
|
|
84
|
-
uvicorn_logger.addHandler(stream_handler)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
_add_timestamp_prefix_for_server_logs()
|
|
88
101
|
logger = sky_logging.init_logger(__name__)
|
|
89
102
|
|
|
90
103
|
# TODO(zhwu): Streaming requests, such log tailing after sky launch or sky logs,
|
|
@@ -92,17 +105,315 @@ logger = sky_logging.init_logger(__name__)
|
|
|
92
105
|
# response will block other requests from being processed.
|
|
93
106
|
|
|
94
107
|
|
|
108
|
+
def _basic_auth_401_response(content: str):
    """Build a 401 JSON response that challenges the client for basic auth.

    Args:
        content: Body of the response, describing why authentication failed.

    Returns:
        A JSONResponse with status code 401 and a WWW-Authenticate header
        for the "SkyPilot" realm.
    """
    challenge_headers = {'WWW-Authenticate': 'Basic realm=\"SkyPilot\"'}
    return fastapi.responses.JSONResponse(status_code=401,
                                          headers=challenge_headers,
                                          content=content)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _try_set_basic_auth_user(request: fastapi.Request):
    """Best-effort: record the basic-auth user on the request, if valid.

    This never rejects the request. A missing, non-basic, malformed or
    incorrect Authorization header simply leaves request.state.auth_user
    untouched.

    Args:
        request: The incoming request whose Authorization header is inspected.
    """
    header_value = request.headers.get('authorization')
    if not header_value:
        return
    if not header_value.lower().startswith('basic '):
        return

    # Check username and password. The part after the scheme is expected to
    # be base64("<username>:<password>").
    b64_credentials = header_value.split(' ', 1)[1]
    try:
        credentials = base64.b64decode(b64_credentials).decode()
        username, password = credentials.split(':', 1)
    except Exception:  # pylint: disable=broad-except
        return

    candidates = global_user_state.get_user_by_name(username)
    if not candidates:
        return

    for candidate in candidates:
        # Skip records that have no name or no password to verify against.
        if not (candidate.name and candidate.password):
            continue
        if username.encode('utf8') != candidate.name.encode('utf8'):
            continue
        if common.crypt_ctx.verify(password, candidate.password):
            request.state.auth_user = candidate
            break
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@middleware_utils.websocket_aware
class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to handle RBAC."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Reject the request with 403 if the auth user may not access it.

        Requests under '/dashboard/' or '/api/', and requests without an
        authenticated user, are passed through without a permission check.
        """
        # TODO(hailong): should have a list of paths
        # that are not checked for RBAC
        path = request.url.path
        if path.startswith('/dashboard/') or path.startswith('/api/'):
            return await call_next(request)

        auth_user = request.state.auth_user
        if auth_user is not None:
            # Check the role permission.
            # NOTE(review): a truthy result is treated as "denied" here;
            # confirm against check_endpoint_permission's contract.
            denied = permission.permission_service.check_endpoint_permission(
                auth_user.id, path, request.method)
            if denied:
                return fastapi.responses.JSONResponse(
                    status_code=403, content={'detail': 'Forbidden'})

        return await call_next(request)
|
|
168
|
+
|
|
169
|
+
|
|
95
170
|
class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to add a request ID to each request."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Attach a fresh request ID to the request and echo it in a header.

        The ID is stored on request.state.request_id before the request is
        handled and returned to the client in 'X-Skypilot-Request-ID'.
        """
        new_id = requests_lib.get_new_request_id()
        request.state.request_id = new_id
        response = await call_next(request)
        response.headers['X-Skypilot-Request-ID'] = new_id
        return response
|
|
104
179
|
|
|
105
180
|
|
|
181
|
+
def _get_auth_user_header(request: fastapi.Request) -> Optional[models.User]:
    """Derive a user from the auth-proxy identity header, if present.

    The header name is read from the environment variable named by
    constants.ENV_VAR_SERVER_AUTH_USER_HEADER and defaults to
    'X-Auth-Request-Email'.

    Args:
        request: The incoming request.

    Returns:
        A User whose name is the header value and whose id is a truncated
        MD5 hex digest of that value, or None if the header is absent.
    """
    header_name = os.environ.get(constants.ENV_VAR_SERVER_AUTH_USER_HEADER,
                                 'X-Auth-Request-Email')
    user_name = request.headers.get(header_name)
    if user_name is None:
        return None
    # The hash only serves as a stable identifier derived from the name.
    digest = hashlib.md5(user_name.encode()).hexdigest()
    return models.User(id=digest[:common_utils.USER_HASH_LENGTH],
                       name=user_name)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@middleware_utils.websocket_aware
class InitializeRequestAuthUserMiddleware(
        starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware that initializes request.state.auth_user to None."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Set a default auth_user on the request, then continue."""
        # Make sure that request.state.auth_user is set. Otherwise, we may get a
        # KeyError while trying to read it.
        request.state.auth_user = None
        response = await call_next(request)
        return response
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@middleware_utils.websocket_aware
class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to handle HTTP Basic Auth."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Require valid basic-auth credentials before handling the request.

        Two kinds of requests skip enforcement: loopback requests while in
        consolidation mode, and paths starting with '/api/health' (for
        which the auth user is still set when valid credentials are sent).
        Every other request gets a 401 basic-auth challenge unless its
        Authorization header carries a known username with a matching
        password.
        """
        in_consolidation_mode = managed_job_utils.is_consolidation_mode()
        if in_consolidation_mode and loopback.is_loopback_request(request):
            return await call_next(request)

        if request.url.path.startswith('/api/health'):
            # Try to set the auth user from basic auth
            _try_set_basic_auth_user(request)
            return await call_next(request)

        header_value = request.headers.get('authorization')
        if not header_value:
            return _basic_auth_401_response('Authentication required')

        # Only handle basic auth
        if not header_value.lower().startswith('basic '):
            return _basic_auth_401_response('Invalid authentication method')

        # Check username and password
        b64_credentials = header_value.split(' ', 1)[1]
        try:
            credentials = base64.b64decode(b64_credentials).decode()
            username, password = credentials.split(':', 1)
        except Exception:  # pylint: disable=broad-except
            return _basic_auth_401_response('Invalid basic auth')

        candidates = global_user_state.get_user_by_name(username)
        if not candidates:
            return _basic_auth_401_response('Invalid credentials')

        for candidate in candidates:
            # Skip records that have no name or no password to verify against.
            if not (candidate.name and candidate.password):
                continue
            names_match = (
                username.encode('utf8') == candidate.name.encode('utf8'))
            if names_match and common.crypt_ctx.verify(password,
                                                       candidate.password):
                request.state.auth_user = candidate
                await authn.override_user_info_in_request_body(
                    request, candidate)
                break
        else:
            # No candidate matched both the username and the password.
            return _basic_auth_401_response('Invalid credentials')

        return await call_next(request)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
@middleware_utils.websocket_aware
class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to handle Bearer Token Auth (Service Accounts)."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Make sure correct bearer token auth is present.

        1. If the request has the X-Skypilot-Auth-Mode: token header, it must
           have a valid bearer token.
        2. For backwards compatibility, if the request has a Bearer token
           beginning with "sky_" (even if X-Skypilot-Auth-Mode is not present),
           it must be a valid token.
        3. If X-Skypilot-Auth-Mode is not set to "token", and there is no Bearer
           token beginning with "sky_", allow the request to continue.

        In conjunction with an auth proxy, the idea is to make the auth proxy
        bypass requests with bearer tokens, instead setting the
        X-Skypilot-Auth-Mode header. The auth proxy should either validate the
        auth or set the header X-Skypilot-Auth-Mode: token.
        """
        has_skypilot_auth_header = (
            request.headers.get('X-Skypilot-Auth-Mode') == 'token')
        auth_header = request.headers.get('authorization')
        has_bearer_token_starting_with_sky = (
            auth_header and auth_header.lower().startswith('bearer ') and
            auth_header.split(' ', 1)[1].startswith('sky_'))

        if (not has_skypilot_auth_header and
                not has_bearer_token_starting_with_sky):
            # This is case #3 above. We do not need to validate the request.
            # No Bearer token, continue with normal processing (OAuth2 cookies,
            # etc.)
            return await call_next(request)
        # After this point, all requests must be validated.

        if auth_header is None:
            return fastapi.responses.JSONResponse(
                status_code=401, content={'detail': 'Authentication required'})

        # Extract token
        scheme, *token_parts = auth_header.split(' ', 1)
        if scheme.lower() != 'bearer':
            return fastapi.responses.JSONResponse(
                status_code=401,
                content={'detail': 'Invalid authentication method'})
        if not token_parts:
            # The header is just the scheme (e.g. "Bearer") with no token.
            # Answer with 401 instead of failing on a missing list element.
            return fastapi.responses.JSONResponse(
                status_code=401, content={'detail': 'Authentication required'})
        sa_token = token_parts[0]

        # Handle SkyPilot service account tokens
        return await self._handle_service_account_token(request, sa_token,
                                                        call_next)

    async def _handle_service_account_token(self, request: fastapi.Request,
                                            sa_token: str, call_next):
        """Handle SkyPilot service account tokens.

        Args:
            request: The incoming request.
            sa_token: The token taken from the Authorization header.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response when the token is valid; otherwise a 401
            JSON response describing why authentication failed.
        """
        # Check if service account tokens are enabled
        sa_enabled = os.environ.get(constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
                                    'false').lower()
        if sa_enabled != 'true':
            return fastapi.responses.JSONResponse(
                status_code=401,
                content={'detail': 'Service account authentication disabled'})

        try:
            # Import here to avoid circular imports
            # pylint: disable=import-outside-toplevel
            from sky.users.token_service import token_service

            # Verify and decode JWT token
            payload = token_service.verify_token(sa_token)

            if payload is None:
                logger.warning('Service account token verification failed')
                return fastapi.responses.JSONResponse(
                    status_code=401,
                    content={
                        'detail': 'Invalid or expired service account token'
                    })

            # Extract user information from JWT payload
            user_id = payload.get('sub')
            user_name = payload.get('name')
            token_id = payload.get('token_id')

            if not user_id or not token_id:
                logger.warning(
                    'Invalid token payload: missing user_id or token_id')
                return fastapi.responses.JSONResponse(
                    status_code=401,
                    content={'detail': 'Invalid token payload'})

            # Verify user still exists in database
            user_info = global_user_state.get_user(user_id)
            if user_info is None:
                logger.warning(
                    f'Service account user {user_id} no longer exists')
                return fastapi.responses.JSONResponse(
                    status_code=401,
                    content={'detail': 'Service account user no longer exists'})

            # Update last used timestamp for token tracking. This is
            # best-effort: a failure here must not fail the request.
            try:
                global_user_state.update_service_account_token_last_used(
                    token_id)
            except Exception as e:  # pylint: disable=broad-except
                logger.debug(f'Failed to update token last used time: {e}')

            # Set the authenticated user
            auth_user = models.User(id=user_id,
                                    name=user_name or user_info.name)
            request.state.auth_user = auth_user

            # Override user info in request body for service account requests
            await authn.override_user_info_in_request_body(request, auth_user)

            logger.debug(f'Authenticated service account: {user_id}')

        except Exception as e:  # pylint: disable=broad-except
            logger.error(f'Service account authentication failed: {e}',
                         exc_info=True)
            return fastapi.responses.JSONResponse(
                status_code=401,
                content={
                    'detail': f'Service account authentication failed: {str(e)}'
                })

        return await call_next(request)
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
@middleware_utils.websocket_aware
class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to handle auth proxy."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Adopt the auth-proxy header user unless a user is already set.

        If request.state.auth_user is already set, the proxy header is
        ignored. Otherwise a user derived from the header (if any) is
        upserted into the database and stored on request.state.auth_user.
        """
        proxy_user = _get_auth_user_header(request)

        if request.state.auth_user is not None:
            # Previous middleware is trusted more than this middleware. For
            # instance, a client could set the Authorization and the
            # X-Auth-Request-Email header. In that case, the auth proxy will be
            # skipped and we should rely on the Bearer token to authenticate the
            # user - but that means the user could set X-Auth-Request-Email to
            # whatever the user wants. We should thus ignore it.
            if proxy_user is not None:
                logger.debug('Warning: ignoring auth proxy header since the '
                             'auth user was already set.')
            return await call_next(request)

        if proxy_user is not None:
            # Add user to database if auth_user is present
            is_new_user = global_user_state.add_or_update_user(proxy_user)
            if is_new_user:
                permission.permission_service.add_user_if_not_exists(
                    proxy_user.id)
            # Store user info in request.state for access by GET endpoints
            request.state.auth_user = proxy_user

        # NOTE(review): kept outside the check above, so it also runs when
        # no proxy user was found - confirm this nesting is intended.
        await authn.override_user_info_in_request_body(request, proxy_user)
        return await call_next(request)
|
|
415
|
+
|
|
416
|
+
|
|
106
417
|
# Default expiration time for upload ids before cleanup.
|
|
107
418
|
_DEFAULT_UPLOAD_EXPIRATION_TIME = datetime.timedelta(hours=1)
|
|
108
419
|
# Key: (upload_id, user_hash), Value: the time when the upload id needs to be
|
|
@@ -132,21 +443,74 @@ async def cleanup_upload_ids():
|
|
|
132
443
|
upload_ids_to_cleanup.pop((upload_id, user_hash))
|
|
133
444
|
|
|
134
445
|
|
|
446
|
+
async def loop_lag_monitor(loop: asyncio.AbstractEventLoop,
                           interval: float = 0.1) -> None:
    """Start a recurring callback that measures event loop lag.

    A callback is scheduled on `loop` for `interval` seconds from now. Each
    time it fires, it records how late it ran relative to its scheduled time
    into the event-loop-lag metric (labelled by this process's pid), logs a
    warning when the lag exceeds the configured threshold, and reschedules
    itself. This coroutine returns right after scheduling the first callback.

    Args:
        loop: The event loop to monitor.
        interval: Seconds between consecutive measurements.
    """
    process_id = str(os.getpid())
    threshold = perf_utils.get_loop_lag_threshold()
    deadline = loop.time() + interval

    def _measure_and_reschedule():
        nonlocal deadline
        fired_at = loop.time()
        # A callback that fires early counts as zero lag.
        lag = max(0.0, fired_at - deadline)
        if threshold is not None and lag > threshold:
            logger.warning(f'Event loop lag {lag} seconds exceeds threshold '
                           f'{threshold} seconds.')
        lag_metric = metrics_utils.SKY_APISERVER_EVENT_LOOP_LAG_SECONDS
        lag_metric.labels(pid=process_id).observe(lag)
        deadline = fired_at + interval
        loop.call_at(deadline, _measure_and_reschedule)

    loop.call_at(deadline, _measure_and_reschedule)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
async def schedule_on_boot_check_async():
    """Schedule the on-boot `check` request as a SkyPilot system request.

    The request uses a fixed request ID; if a request with that ID already
    exists, the resulting error is logged at debug level and ignored.
    """
    on_boot_request_id = 'skypilot-server-on-boot-check'
    try:
        await executor.schedule_request_async(
            request_id=on_boot_request_id,
            request_name=request_names.RequestName.CHECK,
            request_body=payloads.CheckBody(),
            func=sky_check.check,
            schedule_type=requests_lib.ScheduleType.SHORT,
            is_skypilot_system=True,
        )
    except exceptions.RequestAlreadyExistsError:
        # Lifespan will be executed in each uvicorn worker process, we
        # can safely ignore the error if the task is already scheduled.
        logger.debug(f'Request {on_boot_request_id} already exists.')
|
|
482
|
+
|
|
483
|
+
|
|
135
484
|
@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI):  # pylint: disable=redefined-outer-name
    """FastAPI lifespan context manager.

    On startup this schedules every internal request daemon that is not
    skipped, schedules the on-boot check, and starts the background tasks
    (upload-id cleanup and, when metrics are enabled, the event loop lag
    monitor). Control is then yielded until the application shuts down.

    Args:
        app: The FastAPI application (unused).
    """
    del app  # unused
    # Startup: Run background tasks
    for event in daemons.INTERNAL_REQUEST_DAEMONS:
        if event.should_skip():
            continue
        try:
            await executor.schedule_request_async(
                request_id=event.id,
                request_name=event.name,
                request_body=payloads.RequestBody(),
                func=event.run_event,
                schedule_type=requests_lib.ScheduleType.SHORT,
                is_skypilot_system=True,
                # Request daemon should be retried if the process pool is
                # broken.
                retryable=True,
            )
        except exceptions.RequestAlreadyExistsError:
            # Lifespan will be executed in each uvicorn worker process, we
            # can safely ignore the error if the task is already scheduled.
            logger.debug(f'Request {event.id} already exists.')
    await schedule_on_boot_check_async()
    # Keep strong references to the background tasks: the event loop only
    # holds weak references, so an unreferenced task may be garbage
    # collected before it finishes.
    background_tasks = [asyncio.create_task(cleanup_upload_ids())]
    if metrics_utils.METRICS_ENABLED:
        # Start monitoring the event loop lag in each server worker
        # event loop (process).
        background_tasks.append(
            asyncio.create_task(
                loop_lag_monitor(asyncio.get_running_loop())))
    yield
    # Shutdown: Add any cleanup code here if needed
    del background_tasks
|
|
152
516
|
|
|
@@ -164,8 +528,104 @@ class InternalDashboardPrefixMiddleware(
|
|
|
164
528
|
return await call_next(request)
|
|
165
529
|
|
|
166
530
|
|
|
531
|
+
class CacheControlStaticMiddleware(starlette.middleware.base.BaseHTTPMiddleware
                                  ):
    """Middleware to add cache control headers to static files."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Set a one-hour Cache-Control on '/dashboard/_next' responses."""
        is_static_asset = request.url.path.startswith('/dashboard/_next')
        response = await call_next(request)
        if is_static_asset:
            response.headers['Cache-Control'] = 'max-age=3600'
        return response
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
class PathCleanMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to check the path of requests."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Deny dashboard requests whose normalized path leaves /dashboard.

        Paths that do not start with '/dashboard/' are not inspected.
        """
        url_path = request.url.path
        if not url_path.startswith('/dashboard/'):
            return await call_next(request)

        # If the requested path is not relative to the expected directory,
        # then the user is attempting path traversal, so deny the request.
        normalized = pathlib.Path(posixpath.normpath(url_path))
        if _is_relative_to(normalized, pathlib.Path('/dashboard')):
            return await call_next(request)
        return fastapi.responses.JSONResponse(status_code=403,
                                              content={'detail': 'Forbidden'})
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
@middleware_utils.websocket_aware
class GracefulShutdownMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to control requests when server is shutting down."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Answer 503 for non-'/api/' paths while requests are blocked."""
        # Allow /api/ paths to continue, which are critical to operate
        # on-going requests but will not submit new requests.
        should_reject = (state.get_block_requests() and
                         not request.url.path.startswith('/api/'))
        if not should_reject:
            return await call_next(request)

        # Client will retry on 503 error.
        return fastapi.responses.JSONResponse(
            status_code=503,
            content={
                'detail': 'Server is shutting down, '
                          'please try again later.'
            })
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
@middleware_utils.websocket_aware
class APIVersionMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to add API version to the request."""

    async def dispatch(self, request: fastapi.Request, call_next):
        """Check client/server version compatibility and stamp the response.

        When the compatibility check reports an error, a 400 response is
        returned without handling the request. Otherwise the client's
        versions are recorded and the request proceeds. In both cases the
        server's API version and readable version are added as response
        headers.
        """
        version_info = versions.check_compatibility_at_server(request.headers)
        # Bypass version handling for backward compatibility with clients prior
        # to v0.11.0, the client will check the version in the body of
        # /api/health response and hint an upgrade.
        # TODO(aylei): remove this after v0.13.0 is released.
        if version_info is None:
            return await call_next(request)

        if version_info.error is not None:
            response = fastapi.responses.JSONResponse(
                status_code=400,
                content={
                    'error': common.ApiServerStatus.VERSION_MISMATCH.value,
                    'message': version_info.error,
                })
        else:
            versions.set_remote_api_version(version_info.api_version)
            versions.set_remote_version(version_info.version)
            response = await call_next(request)

        api_version_value = str(server_constants.API_VERSION)
        response.headers[server_constants.API_VERSION_HEADER] = (
            api_version_value)
        response.headers[server_constants.VERSION_HEADER] = (
            versions.get_local_readable_version())
        return response
|
|
606
|
+
|
|
607
|
+
|
|
167
608
|
app = fastapi.FastAPI(prefix='/api/v1', debug=True, lifespan=lifespan)
|
|
609
|
+
# Middleware wraps in the order defined here. E.g., given
|
|
610
|
+
# app.add_middleware(Middleware1)
|
|
611
|
+
# app.add_middleware(Middleware2)
|
|
612
|
+
# app.add_middleware(Middleware3)
|
|
613
|
+
# The effect will be like:
|
|
614
|
+
# Middleware3(Middleware2(Middleware1(request)))
|
|
615
|
+
# If MiddlewareN does something like print(n); call_next(); print(n), you'll get
|
|
616
|
+
# 3; 2; 1; <request>; 1; 2; 3
|
|
617
|
+
# Use environment variable to make the metrics middleware optional.
|
|
618
|
+
if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
|
|
619
|
+
app.add_middleware(metrics.PrometheusMiddleware)
|
|
620
|
+
app.add_middleware(APIVersionMiddleware)
|
|
621
|
+
# The order of all the authentication-related middleware is important.
|
|
622
|
+
# RBACMiddleware must precede all the auth middleware, so it can access
|
|
623
|
+
# request.state.auth_user.
|
|
624
|
+
app.add_middleware(RBACMiddleware)
|
|
168
625
|
app.add_middleware(InternalDashboardPrefixMiddleware)
|
|
626
|
+
app.add_middleware(GracefulShutdownMiddleware)
|
|
627
|
+
app.add_middleware(PathCleanMiddleware)
|
|
628
|
+
app.add_middleware(CacheControlStaticMiddleware)
|
|
169
629
|
app.add_middleware(
|
|
170
630
|
cors.CORSMiddleware,
|
|
171
631
|
# TODO(zhwu): in production deployment, we should restrict the allowed
|
|
@@ -174,19 +634,104 @@ app.add_middleware(
|
|
|
174
634
|
allow_credentials=True,
|
|
175
635
|
allow_methods=['*'],
|
|
176
636
|
allow_headers=['*'],
|
|
177
|
-
expose_headers=['X-Request-ID'])
|
|
637
|
+
expose_headers=['X-Skypilot-Request-ID'])
|
|
638
|
+
# Authentication based on oauth2-proxy.
|
|
639
|
+
app.add_middleware(oauth2_proxy.OAuth2ProxyMiddleware)
|
|
640
|
+
# AuthProxyMiddleware should precede BasicAuthMiddleware and
|
|
641
|
+
# BearerTokenMiddleware, since it should be skipped if either of those set the
|
|
642
|
+
# auth user.
|
|
643
|
+
app.add_middleware(AuthProxyMiddleware)
|
|
644
|
+
enable_basic_auth = os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false')
|
|
645
|
+
if str(enable_basic_auth).lower() == 'true':
|
|
646
|
+
app.add_middleware(BasicAuthMiddleware)
|
|
647
|
+
# Bearer token middleware should always be present to handle service account
|
|
648
|
+
# authentication
|
|
649
|
+
app.add_middleware(BearerTokenMiddleware)
|
|
650
|
+
# InitializeRequestAuthUserMiddleware must be the last added middleware so that
|
|
651
|
+
# request.state.auth_user is always set, but can be overridden by the auth
|
|
652
|
+
# middleware above.
|
|
653
|
+
app.add_middleware(InitializeRequestAuthUserMiddleware)
|
|
178
654
|
app.add_middleware(RequestIDMiddleware)
|
|
179
655
|
app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
|
|
180
656
|
app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
|
|
657
|
+
app.include_router(users_rest.router, prefix='/users', tags=['users'])
|
|
658
|
+
app.include_router(workspaces_rest.router,
|
|
659
|
+
prefix='/workspaces',
|
|
660
|
+
tags=['workspaces'])
|
|
661
|
+
app.include_router(volumes_rest.router, prefix='/volumes', tags=['volumes'])
|
|
662
|
+
app.include_router(ssh_node_pools_rest.router,
|
|
663
|
+
prefix='/ssh_node_pools',
|
|
664
|
+
tags=['ssh_node_pools'])
|
|
665
|
+
# increase the resource limit for the server
|
|
666
|
+
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
|
667
|
+
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
@app.exception_handler(exceptions.ConcurrentWorkerExhaustedError)
def handle_concurrent_worker_exhausted_error(
        request: fastapi.Request, e: exceptions.ConcurrentWorkerExhaustedError):
    """Log a ConcurrentWorkerExhaustedError and answer with a 503.

    Args:
        request: The request that triggered the error (unused).
        e: The raised exception.

    Returns:
        A 503 JSON response with a human readable message.
    """
    del request  # request is not used
    # Print detailed error message to server log
    logger.error('Concurrent worker exhausted: '
                 f'{common_utils.format_exception(e)}')
    with ux_utils.enable_traceback():
        # Format the traceback from the exception object itself instead of
        # traceback.format_exc(): the latter relies on the "currently
        # handled exception" of the calling thread, which is not set when
        # this synchronous handler is run outside the raising thread.
        formatted_tb = ''.join(
            traceback.format_exception(type(e), e, e.__traceback__))
        logger.error(f'  Traceback: {formatted_tb}')
    # Return human readable error message to client
    return fastapi.responses.JSONResponse(
        status_code=503,
        content={
            'detail':
                ('The server has exhausted its concurrent worker limit. '
                 'Please try again or scale the server if the load persists.')
        })
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
@app.get('/token')
async def token(request: fastapi.Request,
                local_port: Optional[int] = None) -> fastapi.responses.Response:
    """Serve the HTML page that hands the login token to the client.

    The token is a base64-encoded JSON object holding a version number, the
    id of the user derived from the auth header (or None), and the request's
    cookies. It is substituted into the 'token_page.html' template together
    with a "Logged in as ..." line.

    Args:
        request: The incoming request.
        local_port: Used by the served page's JavaScript; ignored here.

    Returns:
        An HTML response with caching and proxy buffering disabled.

    Raises:
        fastapi.HTTPException: With status 500 if the template is missing.
    """
    # pylint: disable=import-outside-toplevel
    import html

    del local_port  # local_port is used by the served js, but ignored by server
    user = _get_auth_user_header(request)

    token_data = {
        'v': 1,  # Token version number, bump for backwards incompatible.
        'user': user.id if user is not None else None,
        'cookies': request.cookies,
    }
    # Use base64 encoding to avoid having to escape anything in the HTML.
    json_bytes = json.dumps(token_data).encode('utf-8')
    base64_str = base64.b64encode(json_bytes).decode('utf-8')

    html_dir = pathlib.Path(__file__).parent / 'html'
    token_page_path = html_dir / 'token_page.html'
    try:
        with open(token_page_path, 'r', encoding='utf-8') as f:
            html_content = f.read()
    except FileNotFoundError as e:
        raise fastapi.HTTPException(
            status_code=500, detail='Token page template not found.') from e

    # The user name comes straight from a request header, so escape it
    # before embedding it into the page.
    # NOTE(review): assumes USER_PLACEHOLDER sits in an HTML text context in
    # the template - confirm.
    user_info_string = (f'Logged in as {html.escape(user.name)}'
                        if user is not None else '')
    html_content = html_content.replace(
        'SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER',
        base64_str).replace('USER_PLACEHOLDER', user_info_string)

    return fastapi.responses.HTMLResponse(
        content=html_content,
        headers={
            'Cache-Control': 'no-cache, no-transform',
            # X-Accel-Buffering: no is useful for preventing buffering issues
            # with some reverse proxies.
            'X-Accel-Buffering': 'no'
        })
|
|
181
726
|
|
|
182
727
|
|
|
183
728
|
@app.post('/check')
|
|
184
729
|
async def check(request: fastapi.Request,
|
|
185
730
|
check_body: payloads.CheckBody) -> None:
|
|
186
731
|
"""Checks enabled clouds."""
|
|
187
|
-
executor.
|
|
732
|
+
await executor.schedule_request_async(
|
|
188
733
|
request_id=request.state.request_id,
|
|
189
|
-
request_name=
|
|
734
|
+
request_name=request_names.RequestName.CHECK,
|
|
190
735
|
request_body=check_body,
|
|
191
736
|
func=sky_check.check,
|
|
192
737
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -194,12 +739,15 @@ async def check(request: fastapi.Request,
|
|
|
194
739
|
|
|
195
740
|
|
|
196
741
|
@app.get('/enabled_clouds')
async def enabled_clouds(request: fastapi.Request,
                         workspace: Optional[str] = None,
                         expand: bool = False) -> None:
    """Gets enabled clouds on the server.

    Args:
        request: The incoming request; its request ID identifies the
            scheduled request.
        workspace: Forwarded to the request body as `workspace`.
        expand: Forwarded to the request body as `expand`.
    """
    clouds_body = payloads.EnabledCloudsBody(workspace=workspace,
                                             expand=expand)
    await executor.schedule_request_async(
        request_id=request.state.request_id,
        request_name=request_names.RequestName.ENABLED_CLOUDS,
        request_body=clouds_body,
        func=core.enabled_clouds,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
|
|
@@ -211,9 +759,10 @@ async def realtime_kubernetes_gpu_availability(
|
|
|
211
759
|
realtime_gpu_availability_body: payloads.RealtimeGpuAvailabilityRequestBody
|
|
212
760
|
) -> None:
|
|
213
761
|
"""Gets real-time Kubernetes GPU availability."""
|
|
214
|
-
executor.
|
|
762
|
+
await executor.schedule_request_async(
|
|
215
763
|
request_id=request.state.request_id,
|
|
216
|
-
request_name=
|
|
764
|
+
request_name=request_names.RequestName.
|
|
765
|
+
REALTIME_KUBERNETES_GPU_AVAILABILITY,
|
|
217
766
|
request_body=realtime_gpu_availability_body,
|
|
218
767
|
func=core.realtime_kubernetes_gpu_availability,
|
|
219
768
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -226,9 +775,9 @@ async def kubernetes_node_info(
|
|
|
226
775
|
kubernetes_node_info_body: payloads.KubernetesNodeInfoRequestBody
|
|
227
776
|
) -> None:
|
|
228
777
|
"""Gets Kubernetes nodes information and hints."""
|
|
229
|
-
executor.
|
|
778
|
+
await executor.schedule_request_async(
|
|
230
779
|
request_id=request.state.request_id,
|
|
231
|
-
request_name=
|
|
780
|
+
request_name=request_names.RequestName.KUBERNETES_NODE_INFO,
|
|
232
781
|
request_body=kubernetes_node_info_body,
|
|
233
782
|
func=kubernetes_utils.get_kubernetes_node_info,
|
|
234
783
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -237,10 +786,11 @@ async def kubernetes_node_info(
|
|
|
237
786
|
|
|
238
787
|
@app.get('/status_kubernetes')
|
|
239
788
|
async def status_kubernetes(request: fastapi.Request) -> None:
|
|
240
|
-
"""
|
|
241
|
-
|
|
789
|
+
"""[Experimental] Get all SkyPilot resources (including from other
|
|
790
|
+
users) in the current Kubernetes context."""
|
|
791
|
+
await executor.schedule_request_async(
|
|
242
792
|
request_id=request.state.request_id,
|
|
243
|
-
request_name=
|
|
793
|
+
request_name=request_names.RequestName.STATUS_KUBERNETES,
|
|
244
794
|
request_body=payloads.RequestBody(),
|
|
245
795
|
func=core.status_kubernetes,
|
|
246
796
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -252,11 +802,11 @@ async def list_accelerators(
|
|
|
252
802
|
request: fastapi.Request,
|
|
253
803
|
list_accelerator_counts_body: payloads.ListAcceleratorsBody) -> None:
|
|
254
804
|
"""Gets list of accelerators from cloud catalog."""
|
|
255
|
-
executor.
|
|
805
|
+
await executor.schedule_request_async(
|
|
256
806
|
request_id=request.state.request_id,
|
|
257
|
-
request_name=
|
|
807
|
+
request_name=request_names.RequestName.LIST_ACCELERATORS,
|
|
258
808
|
request_body=list_accelerator_counts_body,
|
|
259
|
-
func=
|
|
809
|
+
func=catalog.list_accelerators,
|
|
260
810
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
261
811
|
)
|
|
262
812
|
|
|
@@ -267,11 +817,11 @@ async def list_accelerator_counts(
|
|
|
267
817
|
list_accelerator_counts_body: payloads.ListAcceleratorCountsBody
|
|
268
818
|
) -> None:
|
|
269
819
|
"""Gets list of accelerator counts from cloud catalog."""
|
|
270
|
-
executor.
|
|
820
|
+
await executor.schedule_request_async(
|
|
271
821
|
request_id=request.state.request_id,
|
|
272
|
-
request_name=
|
|
822
|
+
request_name=request_names.RequestName.LIST_ACCELERATOR_COUNTS,
|
|
273
823
|
request_body=list_accelerator_counts_body,
|
|
274
|
-
func=
|
|
824
|
+
func=catalog.list_accelerator_counts,
|
|
275
825
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
276
826
|
)
|
|
277
827
|
|
|
@@ -289,26 +839,39 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
|
|
|
289
839
|
# pairs.
|
|
290
840
|
logger.debug(f'Validating tasks: {validate_body.dag}')
|
|
291
841
|
|
|
842
|
+
context.initialize()
|
|
843
|
+
ctx = context.get()
|
|
844
|
+
assert ctx is not None
|
|
845
|
+
# TODO(aylei): generalize this to all requests without a db record.
|
|
846
|
+
ctx.override_envs(validate_body.env_vars)
|
|
847
|
+
|
|
292
848
|
def validate_dag(dag: dag_utils.dag_lib.Dag):
|
|
293
849
|
# TODO: Admin policy may contain arbitrary code, which may be expensive
|
|
294
850
|
# to run and may block the server thread. However, moving it into the
|
|
295
851
|
# executor adds a ~150ms penalty on the local API server because of
|
|
296
852
|
# added RTTs. For now, we stick to doing the validation inline in the
|
|
297
853
|
# server thread.
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
854
|
+
with admin_policy_utils.apply_and_use_config_in_current_request(
|
|
855
|
+
dag,
|
|
856
|
+
request_name=request_names.AdminPolicyRequestName.VALIDATE,
|
|
857
|
+
request_options=validate_body.get_request_options()) as dag:
|
|
858
|
+
dag.resolve_and_validate_volumes()
|
|
859
|
+
# Skip validating workdir and file_mounts, as those need to be
|
|
860
|
+
# validated after the files are uploaded to the SkyPilot API server
|
|
861
|
+
# with `upload_mounts_to_api_server`.
|
|
862
|
+
dag.validate(skip_file_mounts=True, skip_workdir=True)
|
|
304
863
|
|
|
305
864
|
try:
|
|
306
865
|
dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag)
|
|
307
|
-
loop = asyncio.get_running_loop()
|
|
308
866
|
# Apply admin policy and validate DAG is blocking, run it in a separate
|
|
309
867
|
# thread executor to avoid blocking the uvicorn event loop.
|
|
310
|
-
await
|
|
868
|
+
await context_utils.to_thread(validate_dag, dag)
|
|
311
869
|
except Exception as e: # pylint: disable=broad-except
|
|
870
|
+
# Print the exception to the API server log.
|
|
871
|
+
if env_options.Options.SHOW_DEBUG_INFO.get():
|
|
872
|
+
logger.info('/validate exception:', exc_info=True)
|
|
873
|
+
# Set the exception stacktrace for the serialized exception.
|
|
874
|
+
requests_lib.set_exception_stacktrace(e)
|
|
312
875
|
raise fastapi.HTTPException(
|
|
313
876
|
status_code=400, detail=exceptions.serialize_exception(e)) from e
|
|
314
877
|
|
|
@@ -317,9 +880,9 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
|
|
|
317
880
|
async def optimize(optimize_body: payloads.OptimizeBody,
|
|
318
881
|
request: fastapi.Request) -> None:
|
|
319
882
|
"""Optimizes the user's DAG."""
|
|
320
|
-
executor.
|
|
883
|
+
await executor.schedule_request_async(
|
|
321
884
|
request_id=request.state.request_id,
|
|
322
|
-
request_name=
|
|
885
|
+
request_name=request_names.RequestName.OPTIMIZE,
|
|
323
886
|
request_body=optimize_body,
|
|
324
887
|
ignore_return_value=True,
|
|
325
888
|
func=core.optimize,
|
|
@@ -347,16 +910,30 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
|
|
|
347
910
|
chunk_index: The chunk index, starting from 0.
|
|
348
911
|
total_chunks: The total number of chunks.
|
|
349
912
|
"""
|
|
913
|
+
# Field _body would be set if the request body has been received, fail fast
|
|
914
|
+
# to surface potential memory issues, i.e. catch the issue in our smoke
|
|
915
|
+
# test.
|
|
916
|
+
# pylint: disable=protected-access
|
|
917
|
+
if hasattr(request, '_body'):
|
|
918
|
+
raise fastapi.HTTPException(
|
|
919
|
+
status_code=500,
|
|
920
|
+
detail='Upload request body should not be received before streaming'
|
|
921
|
+
)
|
|
350
922
|
# Add the upload id to the cleanup list.
|
|
351
923
|
upload_ids_to_cleanup[(upload_id,
|
|
352
924
|
user_hash)] = (datetime.datetime.now() +
|
|
353
925
|
_DEFAULT_UPLOAD_EXPIRATION_TIME)
|
|
926
|
+
# For anonymous access, use the user hash from client
|
|
927
|
+
user_id = user_hash
|
|
928
|
+
if request.state.auth_user is not None:
|
|
929
|
+
# Otherwise, the authenticated identity should be used.
|
|
930
|
+
user_id = request.state.auth_user.id
|
|
354
931
|
|
|
355
932
|
# TODO(SKY-1271): We need to double check security of uploading zip file.
|
|
356
933
|
client_file_mounts_dir = (
|
|
357
|
-
common.API_SERVER_CLIENT_DIR.expanduser().resolve() /
|
|
934
|
+
common.API_SERVER_CLIENT_DIR.expanduser().resolve() / user_id /
|
|
358
935
|
'file_mounts')
|
|
359
|
-
client_file_mounts_dir.mkdir(parents=True, exist_ok=True)
|
|
936
|
+
await anyio.Path(client_file_mounts_dir).mkdir(parents=True, exist_ok=True)
|
|
360
937
|
|
|
361
938
|
# Check upload_id to be a valid SkyPilot run_timestamp appended with 8 hex
|
|
362
939
|
# characters, e.g. 'sky-2025-01-17-09-10-13-933602-35d31c22'.
|
|
@@ -379,7 +956,7 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
|
|
|
379
956
|
zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
|
|
380
957
|
else:
|
|
381
958
|
chunk_dir = client_file_mounts_dir / upload_id
|
|
382
|
-
chunk_dir.mkdir(parents=True, exist_ok=True)
|
|
959
|
+
await anyio.Path(chunk_dir).mkdir(parents=True, exist_ok=True)
|
|
383
960
|
zip_file_path = chunk_dir / f'part{chunk_index}.incomplete'
|
|
384
961
|
|
|
385
962
|
try:
|
|
@@ -409,8 +986,9 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
|
|
|
409
986
|
zip_file_path.rename(zip_file_path.with_suffix(''))
|
|
410
987
|
missing_chunks = get_missing_chunks(total_chunks)
|
|
411
988
|
if missing_chunks:
|
|
412
|
-
return payloads.UploadZipFileResponse(
|
|
413
|
-
|
|
989
|
+
return payloads.UploadZipFileResponse(
|
|
990
|
+
status=responses.UploadStatus.UPLOADING.value,
|
|
991
|
+
missing_chunks=missing_chunks)
|
|
414
992
|
zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
|
|
415
993
|
async with aiofiles.open(zip_file_path, 'wb') as zip_file:
|
|
416
994
|
for chunk in range(total_chunks):
|
|
@@ -424,10 +1002,11 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
|
|
|
424
1002
|
await zip_file.write(data)
|
|
425
1003
|
|
|
426
1004
|
logger.info(f'Uploaded zip file: {zip_file_path}')
|
|
427
|
-
unzip_file(zip_file_path, client_file_mounts_dir)
|
|
1005
|
+
await unzip_file(zip_file_path, client_file_mounts_dir)
|
|
428
1006
|
if total_chunks > 1:
|
|
429
|
-
shutil.rmtree
|
|
430
|
-
return payloads.UploadZipFileResponse(
|
|
1007
|
+
await context_utils.to_thread(shutil.rmtree, chunk_dir)
|
|
1008
|
+
return payloads.UploadZipFileResponse(
|
|
1009
|
+
status=responses.UploadStatus.COMPLETED.value)
|
|
431
1010
|
|
|
432
1011
|
|
|
433
1012
|
def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
|
|
@@ -440,61 +1019,69 @@ def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
|
|
|
440
1019
|
return False
|
|
441
1020
|
|
|
442
1021
|
|
|
443
|
-
def unzip_file(zip_file_path: pathlib.Path,
|
|
444
|
-
|
|
445
|
-
"""Unzips a zip file."""
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
1022
|
+
async def unzip_file(zip_file_path: pathlib.Path,
|
|
1023
|
+
client_file_mounts_dir: pathlib.Path) -> None:
|
|
1024
|
+
"""Unzips a zip file without blocking the event loop."""
|
|
1025
|
+
|
|
1026
|
+
def _do_unzip() -> None:
|
|
1027
|
+
try:
|
|
1028
|
+
with zipfile.ZipFile(zip_file_path, 'r') as zipf:
|
|
1029
|
+
for member in zipf.infolist():
|
|
1030
|
+
# Determine the new path
|
|
1031
|
+
original_path = os.path.normpath(member.filename)
|
|
1032
|
+
new_path = client_file_mounts_dir / original_path.lstrip(
|
|
1033
|
+
'/')
|
|
1034
|
+
|
|
1035
|
+
if (member.external_attr >> 28) == 0xA:
|
|
1036
|
+
# Symlink. Read the target path and create a symlink.
|
|
1037
|
+
new_path.parent.mkdir(parents=True, exist_ok=True)
|
|
1038
|
+
target = zipf.read(member).decode()
|
|
1039
|
+
assert not os.path.isabs(target), target
|
|
1040
|
+
# Since target is a relative path, we need to check that
|
|
1041
|
+
# it is under `client_file_mounts_dir` for security.
|
|
1042
|
+
full_target_path = (new_path.parent / target).resolve()
|
|
1043
|
+
if not _is_relative_to(full_target_path,
|
|
1044
|
+
client_file_mounts_dir):
|
|
1045
|
+
raise ValueError(
|
|
1046
|
+
f'Symlink target {target} leads to a '
|
|
1047
|
+
'file not in userspace. Aborted.')
|
|
1048
|
+
|
|
1049
|
+
if new_path.exists() or new_path.is_symlink():
|
|
1050
|
+
new_path.unlink(missing_ok=True)
|
|
1051
|
+
new_path.symlink_to(
|
|
1052
|
+
target,
|
|
1053
|
+
target_is_directory=member.filename.endswith('/'))
|
|
1054
|
+
continue
|
|
1055
|
+
|
|
1056
|
+
# Handle directories
|
|
1057
|
+
if member.filename.endswith('/'):
|
|
1058
|
+
new_path.mkdir(parents=True, exist_ok=True)
|
|
1059
|
+
continue
|
|
1060
|
+
|
|
1061
|
+
# Handle files
|
|
455
1062
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
continue
|
|
477
|
-
|
|
478
|
-
# Handle files
|
|
479
|
-
new_path.parent.mkdir(parents=True, exist_ok=True)
|
|
480
|
-
with zipf.open(member) as member_file, new_path.open('wb') as f:
|
|
481
|
-
# Use shutil.copyfileobj to copy files in chunks, so it does
|
|
482
|
-
# not load the entire file into memory.
|
|
483
|
-
shutil.copyfileobj(member_file, f)
|
|
484
|
-
except zipfile.BadZipFile as e:
|
|
485
|
-
logger.error(f'Bad zip file: {zip_file_path}')
|
|
486
|
-
raise fastapi.HTTPException(
|
|
487
|
-
status_code=400,
|
|
488
|
-
detail=f'Invalid zip file: {common_utils.format_exception(e)}')
|
|
489
|
-
except Exception as e:
|
|
490
|
-
logger.error(f'Error unzipping file: {zip_file_path}')
|
|
491
|
-
raise fastapi.HTTPException(
|
|
492
|
-
status_code=500,
|
|
493
|
-
detail=(f'Error unzipping file: '
|
|
494
|
-
f'{common_utils.format_exception(e)}'))
|
|
1063
|
+
with zipf.open(member) as member_file, new_path.open(
|
|
1064
|
+
'wb') as f:
|
|
1065
|
+
# Use shutil.copyfileobj to copy files in chunks,
|
|
1066
|
+
# so it does not load the entire file into memory.
|
|
1067
|
+
shutil.copyfileobj(member_file, f)
|
|
1068
|
+
except zipfile.BadZipFile as e:
|
|
1069
|
+
logger.error(f'Bad zip file: {zip_file_path}')
|
|
1070
|
+
raise fastapi.HTTPException(
|
|
1071
|
+
status_code=400,
|
|
1072
|
+
detail=f'Invalid zip file: {common_utils.format_exception(e)}')
|
|
1073
|
+
except Exception as e:
|
|
1074
|
+
logger.error(f'Error unzipping file: {zip_file_path}')
|
|
1075
|
+
raise fastapi.HTTPException(
|
|
1076
|
+
status_code=500,
|
|
1077
|
+
detail=(f'Error unzipping file: '
|
|
1078
|
+
f'{common_utils.format_exception(e)}'))
|
|
1079
|
+
finally:
|
|
1080
|
+
# Cleanup the temporary file regardless of
|
|
1081
|
+
# success/failure handling above
|
|
1082
|
+
zip_file_path.unlink(missing_ok=True)
|
|
495
1083
|
|
|
496
|
-
|
|
497
|
-
zip_file_path.unlink()
|
|
1084
|
+
await context_utils.to_thread(_do_unzip)
|
|
498
1085
|
|
|
499
1086
|
|
|
500
1087
|
@app.post('/launch')
|
|
@@ -503,13 +1090,14 @@ async def launch(launch_body: payloads.LaunchBody,
|
|
|
503
1090
|
"""Launches a cluster or task."""
|
|
504
1091
|
request_id = request.state.request_id
|
|
505
1092
|
logger.info(f'Launching request: {request_id}')
|
|
506
|
-
executor.
|
|
1093
|
+
await executor.schedule_request_async(
|
|
507
1094
|
request_id,
|
|
508
|
-
request_name=
|
|
1095
|
+
request_name=request_names.RequestName.CLUSTER_LAUNCH,
|
|
509
1096
|
request_body=launch_body,
|
|
510
1097
|
func=execution.launch,
|
|
511
1098
|
schedule_type=requests_lib.ScheduleType.LONG,
|
|
512
1099
|
request_cluster_name=launch_body.cluster_name,
|
|
1100
|
+
retryable=launch_body.retry_until_up,
|
|
513
1101
|
)
|
|
514
1102
|
|
|
515
1103
|
|
|
@@ -518,9 +1106,9 @@ async def launch(launch_body: payloads.LaunchBody,
|
|
|
518
1106
|
async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
|
|
519
1107
|
"""Executes a task on an existing cluster."""
|
|
520
1108
|
cluster_name = exec_body.cluster_name
|
|
521
|
-
executor.
|
|
1109
|
+
await executor.schedule_request_async(
|
|
522
1110
|
request_id=request.state.request_id,
|
|
523
|
-
request_name=
|
|
1111
|
+
request_name=request_names.RequestName.CLUSTER_EXEC,
|
|
524
1112
|
request_body=exec_body,
|
|
525
1113
|
func=execution.exec,
|
|
526
1114
|
precondition=preconditions.ClusterStartCompletePrecondition(
|
|
@@ -536,9 +1124,9 @@ async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
|
|
|
536
1124
|
async def stop(request: fastapi.Request,
|
|
537
1125
|
stop_body: payloads.StopOrDownBody) -> None:
|
|
538
1126
|
"""Stops a cluster."""
|
|
539
|
-
executor.
|
|
1127
|
+
await executor.schedule_request_async(
|
|
540
1128
|
request_id=request.state.request_id,
|
|
541
|
-
request_name=
|
|
1129
|
+
request_name=request_names.RequestName.CLUSTER_STOP,
|
|
542
1130
|
request_body=stop_body,
|
|
543
1131
|
func=core.stop,
|
|
544
1132
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -552,9 +1140,13 @@ async def status(
|
|
|
552
1140
|
status_body: payloads.StatusBody = payloads.StatusBody()
|
|
553
1141
|
) -> None:
|
|
554
1142
|
"""Gets cluster statuses."""
|
|
555
|
-
|
|
1143
|
+
if state.get_block_requests():
|
|
1144
|
+
raise fastapi.HTTPException(
|
|
1145
|
+
status_code=503,
|
|
1146
|
+
detail='Server is shutting down, please try again later.')
|
|
1147
|
+
await executor.schedule_request_async(
|
|
556
1148
|
request_id=request.state.request_id,
|
|
557
|
-
request_name=
|
|
1149
|
+
request_name=request_names.RequestName.CLUSTER_STATUS,
|
|
558
1150
|
request_body=status_body,
|
|
559
1151
|
func=core.status,
|
|
560
1152
|
schedule_type=(requests_lib.ScheduleType.LONG if
|
|
@@ -567,9 +1159,9 @@ async def status(
|
|
|
567
1159
|
async def endpoints(request: fastapi.Request,
|
|
568
1160
|
endpoint_body: payloads.EndpointsBody) -> None:
|
|
569
1161
|
"""Gets the endpoint for a given cluster and port number (endpoint)."""
|
|
570
|
-
executor.
|
|
1162
|
+
await executor.schedule_request_async(
|
|
571
1163
|
request_id=request.state.request_id,
|
|
572
|
-
request_name=
|
|
1164
|
+
request_name=request_names.RequestName.CLUSTER_ENDPOINTS,
|
|
573
1165
|
request_body=endpoint_body,
|
|
574
1166
|
func=core.endpoints,
|
|
575
1167
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -581,9 +1173,9 @@ async def endpoints(request: fastapi.Request,
|
|
|
581
1173
|
async def down(request: fastapi.Request,
|
|
582
1174
|
down_body: payloads.StopOrDownBody) -> None:
|
|
583
1175
|
"""Tears down a cluster."""
|
|
584
|
-
executor.
|
|
1176
|
+
await executor.schedule_request_async(
|
|
585
1177
|
request_id=request.state.request_id,
|
|
586
|
-
request_name=
|
|
1178
|
+
request_name=request_names.RequestName.CLUSTER_DOWN,
|
|
587
1179
|
request_body=down_body,
|
|
588
1180
|
func=core.down,
|
|
589
1181
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -595,9 +1187,9 @@ async def down(request: fastapi.Request,
|
|
|
595
1187
|
async def start(request: fastapi.Request,
|
|
596
1188
|
start_body: payloads.StartBody) -> None:
|
|
597
1189
|
"""Restarts a cluster."""
|
|
598
|
-
executor.
|
|
1190
|
+
await executor.schedule_request_async(
|
|
599
1191
|
request_id=request.state.request_id,
|
|
600
|
-
request_name=
|
|
1192
|
+
request_name=request_names.RequestName.CLUSTER_START,
|
|
601
1193
|
request_body=start_body,
|
|
602
1194
|
func=core.start,
|
|
603
1195
|
schedule_type=requests_lib.ScheduleType.LONG,
|
|
@@ -609,9 +1201,9 @@ async def start(request: fastapi.Request,
|
|
|
609
1201
|
async def autostop(request: fastapi.Request,
|
|
610
1202
|
autostop_body: payloads.AutostopBody) -> None:
|
|
611
1203
|
"""Schedules an autostop/autodown for a cluster."""
|
|
612
|
-
executor.
|
|
1204
|
+
await executor.schedule_request_async(
|
|
613
1205
|
request_id=request.state.request_id,
|
|
614
|
-
request_name=
|
|
1206
|
+
request_name=request_names.RequestName.CLUSTER_AUTOSTOP,
|
|
615
1207
|
request_body=autostop_body,
|
|
616
1208
|
func=core.autostop,
|
|
617
1209
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -623,9 +1215,9 @@ async def autostop(request: fastapi.Request,
|
|
|
623
1215
|
async def queue(request: fastapi.Request,
|
|
624
1216
|
queue_body: payloads.QueueBody) -> None:
|
|
625
1217
|
"""Gets the job queue of a cluster."""
|
|
626
|
-
executor.
|
|
1218
|
+
await executor.schedule_request_async(
|
|
627
1219
|
request_id=request.state.request_id,
|
|
628
|
-
request_name=
|
|
1220
|
+
request_name=request_names.RequestName.CLUSTER_QUEUE,
|
|
629
1221
|
request_body=queue_body,
|
|
630
1222
|
func=core.queue,
|
|
631
1223
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -637,9 +1229,9 @@ async def queue(request: fastapi.Request,
|
|
|
637
1229
|
async def job_status(request: fastapi.Request,
|
|
638
1230
|
job_status_body: payloads.JobStatusBody) -> None:
|
|
639
1231
|
"""Gets the status of a job."""
|
|
640
|
-
executor.
|
|
1232
|
+
await executor.schedule_request_async(
|
|
641
1233
|
request_id=request.state.request_id,
|
|
642
|
-
request_name=
|
|
1234
|
+
request_name=request_names.RequestName.CLUSTER_JOB_STATUS,
|
|
643
1235
|
request_body=job_status_body,
|
|
644
1236
|
func=core.job_status,
|
|
645
1237
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -651,9 +1243,9 @@ async def job_status(request: fastapi.Request,
|
|
|
651
1243
|
async def cancel(request: fastapi.Request,
|
|
652
1244
|
cancel_body: payloads.CancelBody) -> None:
|
|
653
1245
|
"""Cancels jobs on a cluster."""
|
|
654
|
-
executor.
|
|
1246
|
+
await executor.schedule_request_async(
|
|
655
1247
|
request_id=request.state.request_id,
|
|
656
|
-
request_name=
|
|
1248
|
+
request_name=request_names.RequestName.CLUSTER_JOB_CANCEL,
|
|
657
1249
|
request_body=cancel_body,
|
|
658
1250
|
func=core.cancel,
|
|
659
1251
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -670,36 +1262,27 @@ async def logs(
|
|
|
670
1262
|
# TODO(zhwu): This should wait for the request on the cluster, e.g., async
|
|
671
1263
|
# launch, to finish, so that a user does not need to manually pull the
|
|
672
1264
|
# request status.
|
|
673
|
-
executor.
|
|
1265
|
+
executor.check_request_thread_executor_available()
|
|
1266
|
+
request_task = await executor.prepare_request_async(
|
|
674
1267
|
request_id=request.state.request_id,
|
|
675
|
-
request_name=
|
|
1268
|
+
request_name=request_names.RequestName.CLUSTER_JOB_LOGS,
|
|
676
1269
|
request_body=cluster_job_body,
|
|
677
1270
|
func=core.tail_logs,
|
|
678
|
-
# TODO(aylei): We have tail logs scheduled as SHORT request, because it
|
|
679
|
-
# should be responsive. However, it can be long running if the user's
|
|
680
|
-
# job keeps running, and we should avoid it taking the SHORT worker.
|
|
681
1271
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
682
1272
|
request_cluster_name=cluster_job_body.cluster_name,
|
|
683
1273
|
)
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
1274
|
+
task = executor.execute_request_in_coroutine(request_task)
|
|
1275
|
+
background_tasks.add_task(task.cancel)
|
|
687
1276
|
# TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
|
|
688
1277
|
# the same approach as /stream.
|
|
689
|
-
return stream_utils.
|
|
690
|
-
request_id=
|
|
1278
|
+
return stream_utils.stream_response_for_long_request(
|
|
1279
|
+
request_id=request.state.request_id,
|
|
691
1280
|
logs_path=request_task.log_path,
|
|
692
1281
|
background_tasks=background_tasks,
|
|
1282
|
+
kill_request_on_disconnect=False,
|
|
693
1283
|
)
|
|
694
1284
|
|
|
695
1285
|
|
|
696
|
-
@app.get('/users')
|
|
697
|
-
async def users() -> List[Dict[str, Any]]:
|
|
698
|
-
"""Gets all users."""
|
|
699
|
-
user_list = global_user_state.get_all_users()
|
|
700
|
-
return [user.to_dict() for user in user_list]
|
|
701
|
-
|
|
702
|
-
|
|
703
1286
|
@app.post('/download_logs')
|
|
704
1287
|
async def download_logs(
|
|
705
1288
|
request: fastapi.Request,
|
|
@@ -711,9 +1294,9 @@ async def download_logs(
|
|
|
711
1294
|
# We should reuse the original request body, so that the env vars, such as
|
|
712
1295
|
# user hash, are kept the same.
|
|
713
1296
|
cluster_jobs_body.local_dir = str(logs_dir_on_api_server)
|
|
714
|
-
executor.
|
|
1297
|
+
await executor.schedule_request_async(
|
|
715
1298
|
request_id=request.state.request_id,
|
|
716
|
-
request_name=
|
|
1299
|
+
request_name=request_names.RequestName.CLUSTER_JOB_DOWNLOAD_LOGS,
|
|
717
1300
|
request_body=cluster_jobs_body,
|
|
718
1301
|
func=core.download_logs,
|
|
719
1302
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -722,7 +1305,8 @@ async def download_logs(
|
|
|
722
1305
|
|
|
723
1306
|
|
|
724
1307
|
@app.post('/download')
|
|
725
|
-
async def download(download_body: payloads.DownloadBody
|
|
1308
|
+
async def download(download_body: payloads.DownloadBody,
|
|
1309
|
+
request: fastapi.Request) -> None:
|
|
726
1310
|
"""Downloads a folder from the cluster to the local machine."""
|
|
727
1311
|
folder_paths = [
|
|
728
1312
|
pathlib.Path(folder_path) for folder_path in download_body.folder_paths
|
|
@@ -747,11 +1331,25 @@ async def download(download_body: payloads.DownloadBody) -> None:
|
|
|
747
1331
|
logs_dir_on_api_server).expanduser().resolve() / zip_filename
|
|
748
1332
|
|
|
749
1333
|
try:
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
1334
|
+
|
|
1335
|
+
def _zip_files_and_folders(folder_paths, zip_path):
|
|
1336
|
+
folders = [
|
|
1337
|
+
str(folder_path.expanduser().resolve())
|
|
1338
|
+
for folder_path in folder_paths
|
|
1339
|
+
]
|
|
1340
|
+
# Check for optional query parameter to control zip entry structure
|
|
1341
|
+
relative = request.query_params.get('relative', 'home')
|
|
1342
|
+
if relative == 'items':
|
|
1343
|
+
# Dashboard-friendly: entries relative to selected folders
|
|
1344
|
+
storage_utils.zip_files_and_folders(folders,
|
|
1345
|
+
zip_path,
|
|
1346
|
+
relative_to_items=True)
|
|
1347
|
+
else:
|
|
1348
|
+
# CLI-friendly (default): entries with full paths for mapping
|
|
1349
|
+
storage_utils.zip_files_and_folders(folders, zip_path)
|
|
1350
|
+
|
|
1351
|
+
await context_utils.to_thread(_zip_files_and_folders, folder_paths,
|
|
1352
|
+
zip_path)
|
|
755
1353
|
|
|
756
1354
|
# Add home path to the response headers, so that the client can replace
|
|
757
1355
|
# the remote path in the zip file to the local path.
|
|
@@ -773,13 +1371,84 @@ async def download(download_body: payloads.DownloadBody) -> None:
|
|
|
773
1371
|
detail=f'Error creating zip file: {str(e)}')
|
|
774
1372
|
|
|
775
1373
|
|
|
776
|
-
|
|
777
|
-
|
|
1374
|
+
# TODO(aylei): run it asynchronously after global_user_state support async op
|
|
1375
|
+
@app.post('/provision_logs')
|
|
1376
|
+
def provision_logs(provision_logs_body: payloads.ProvisionLogsBody,
|
|
1377
|
+
follow: bool = True,
|
|
1378
|
+
tail: int = 0) -> fastapi.responses.StreamingResponse:
|
|
1379
|
+
"""Streams the provision.log for the latest launch request of a cluster."""
|
|
1380
|
+
log_path = None
|
|
1381
|
+
cluster_name = provision_logs_body.cluster_name
|
|
1382
|
+
worker = provision_logs_body.worker
|
|
1383
|
+
# stream head node logs
|
|
1384
|
+
if worker is None:
|
|
1385
|
+
# Prefer clusters table first, then cluster_history as fallback.
|
|
1386
|
+
log_path_str = global_user_state.get_cluster_provision_log_path(
|
|
1387
|
+
cluster_name)
|
|
1388
|
+
if not log_path_str:
|
|
1389
|
+
log_path_str = (
|
|
1390
|
+
global_user_state.get_cluster_history_provision_log_path(
|
|
1391
|
+
cluster_name))
|
|
1392
|
+
if not log_path_str:
|
|
1393
|
+
raise fastapi.HTTPException(
|
|
1394
|
+
status_code=404,
|
|
1395
|
+
detail=('Provision log path is not recorded for this cluster. '
|
|
1396
|
+
'Please relaunch to generate provisioning logs.'))
|
|
1397
|
+
log_path = pathlib.Path(log_path_str).expanduser().resolve()
|
|
1398
|
+
if not log_path.exists():
|
|
1399
|
+
raise fastapi.HTTPException(
|
|
1400
|
+
status_code=404,
|
|
1401
|
+
detail=f'Provision log path does not exist: {str(log_path)}')
|
|
1402
|
+
|
|
1403
|
+
# stream worker node logs
|
|
1404
|
+
else:
|
|
1405
|
+
handle = global_user_state.get_handle_from_cluster_name(cluster_name)
|
|
1406
|
+
if handle is None:
|
|
1407
|
+
raise fastapi.HTTPException(
|
|
1408
|
+
status_code=404,
|
|
1409
|
+
detail=('Cluster handle is not recorded for this cluster. '
|
|
1410
|
+
'Please relaunch to generate provisioning logs.'))
|
|
1411
|
+
# instance_ids includes head node
|
|
1412
|
+
instance_ids = handle.instance_ids
|
|
1413
|
+
if instance_ids is None:
|
|
1414
|
+
raise fastapi.HTTPException(
|
|
1415
|
+
status_code=400,
|
|
1416
|
+
detail='Instance IDs are not recorded for this cluster. '
|
|
1417
|
+
'Please relaunch to generate provisioning logs.')
|
|
1418
|
+
if worker > len(instance_ids) - 1:
|
|
1419
|
+
raise fastapi.HTTPException(
|
|
1420
|
+
status_code=400,
|
|
1421
|
+
detail=f'Worker {worker} is out of range. '
|
|
1422
|
+
f'The cluster has {len(instance_ids)} nodes.')
|
|
1423
|
+
log_path = metadata_utils.get_instance_log_dir(
|
|
1424
|
+
handle.get_cluster_name_on_cloud(), instance_ids[worker])
|
|
1425
|
+
|
|
1426
|
+
# Tail semantics: 0 means print all lines. Convert 0 -> None for streamer.
|
|
1427
|
+
effective_tail = None if tail is None or tail <= 0 else tail
|
|
1428
|
+
|
|
1429
|
+
return fastapi.responses.StreamingResponse(
|
|
1430
|
+
content=stream_utils.log_streamer(None,
|
|
1431
|
+
log_path,
|
|
1432
|
+
tail=effective_tail,
|
|
1433
|
+
follow=follow,
|
|
1434
|
+
cluster_name=cluster_name),
|
|
1435
|
+
media_type='text/plain',
|
|
1436
|
+
headers={
|
|
1437
|
+
'Cache-Control': 'no-cache, no-transform',
|
|
1438
|
+
'X-Accel-Buffering': 'no',
|
|
1439
|
+
'Transfer-Encoding': 'chunked',
|
|
1440
|
+
},
|
|
1441
|
+
)
|
|
1442
|
+
|
|
1443
|
+
|
|
1444
|
+
@app.post('/cost_report')
|
|
1445
|
+
async def cost_report(request: fastapi.Request,
|
|
1446
|
+
cost_report_body: payloads.CostReportBody) -> None:
|
|
778
1447
|
"""Gets the cost report of a cluster."""
|
|
779
|
-
executor.
|
|
1448
|
+
await executor.schedule_request_async(
|
|
780
1449
|
request_id=request.state.request_id,
|
|
781
|
-
request_name=
|
|
782
|
-
request_body=
|
|
1450
|
+
request_name=request_names.RequestName.CLUSTER_COST_REPORT,
|
|
1451
|
+
request_body=cost_report_body,
|
|
783
1452
|
func=core.cost_report,
|
|
784
1453
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
785
1454
|
)
|
|
@@ -788,9 +1457,9 @@ async def cost_report(request: fastapi.Request) -> None:
|
|
|
788
1457
|
@app.get('/storage/ls')
|
|
789
1458
|
async def storage_ls(request: fastapi.Request) -> None:
|
|
790
1459
|
"""Gets the storages."""
|
|
791
|
-
executor.
|
|
1460
|
+
await executor.schedule_request_async(
|
|
792
1461
|
request_id=request.state.request_id,
|
|
793
|
-
request_name=
|
|
1462
|
+
request_name=request_names.RequestName.STORAGE_LS,
|
|
794
1463
|
request_body=payloads.RequestBody(),
|
|
795
1464
|
func=core.storage_ls,
|
|
796
1465
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
@@ -801,9 +1470,9 @@ async def storage_ls(request: fastapi.Request) -> None:
|
|
|
801
1470
|
async def storage_delete(request: fastapi.Request,
|
|
802
1471
|
storage_body: payloads.StorageBody) -> None:
|
|
803
1472
|
"""Deletes a storage."""
|
|
804
|
-
executor.
|
|
1473
|
+
await executor.schedule_request_async(
|
|
805
1474
|
request_id=request.state.request_id,
|
|
806
|
-
request_name=
|
|
1475
|
+
request_name=request_names.RequestName.STORAGE_DELETE,
|
|
807
1476
|
request_body=storage_body,
|
|
808
1477
|
func=core.storage_delete,
|
|
809
1478
|
schedule_type=requests_lib.ScheduleType.LONG,
|
|
@@ -814,9 +1483,9 @@ async def storage_delete(request: fastapi.Request,
|
|
|
814
1483
|
async def local_up(request: fastapi.Request,
|
|
815
1484
|
local_up_body: payloads.LocalUpBody) -> None:
|
|
816
1485
|
"""Launches a Kubernetes cluster on API server."""
|
|
817
|
-
executor.
|
|
1486
|
+
await executor.schedule_request_async(
|
|
818
1487
|
request_id=request.state.request_id,
|
|
819
|
-
request_name=
|
|
1488
|
+
request_name=request_names.RequestName.LOCAL_UP,
|
|
820
1489
|
request_body=local_up_body,
|
|
821
1490
|
func=core.local_up,
|
|
822
1491
|
schedule_type=requests_lib.ScheduleType.LONG,
|
|
@@ -824,37 +1493,65 @@ async def local_up(request: fastapi.Request,
|
|
|
824
1493
|
|
|
825
1494
|
|
|
826
1495
|
@app.post('/local_down')
|
|
827
|
-
async def local_down(request: fastapi.Request
|
|
1496
|
+
async def local_down(request: fastapi.Request,
|
|
1497
|
+
local_down_body: payloads.LocalDownBody) -> None:
|
|
828
1498
|
"""Tears down the Kubernetes cluster started by local_up."""
|
|
829
|
-
executor.
|
|
1499
|
+
await executor.schedule_request_async(
|
|
830
1500
|
request_id=request.state.request_id,
|
|
831
|
-
request_name=
|
|
832
|
-
request_body=
|
|
1501
|
+
request_name=request_names.RequestName.LOCAL_DOWN,
|
|
1502
|
+
request_body=local_down_body,
|
|
833
1503
|
func=core.local_down,
|
|
834
1504
|
schedule_type=requests_lib.ScheduleType.LONG,
|
|
835
1505
|
)
|
|
836
1506
|
|
|
837
1507
|
|
|
1508
|
+
async def get_expanded_request_id(request_id: str) -> str:
|
|
1509
|
+
"""Gets the expanded request ID for a given request ID prefix."""
|
|
1510
|
+
request_tasks = await requests_lib.get_requests_async_with_prefix(
|
|
1511
|
+
request_id, fields=['request_id'])
|
|
1512
|
+
if request_tasks is None:
|
|
1513
|
+
raise fastapi.HTTPException(status_code=404,
|
|
1514
|
+
detail=f'Request {request_id!r} not found')
|
|
1515
|
+
if len(request_tasks) > 1:
|
|
1516
|
+
raise fastapi.HTTPException(status_code=400,
|
|
1517
|
+
detail=('Multiple requests found for '
|
|
1518
|
+
f'request ID prefix: {request_id}'))
|
|
1519
|
+
return request_tasks[0].request_id
|
|
1520
|
+
|
|
1521
|
+
|
|
838
1522
|
# === API server related APIs ===
|
|
839
|
-
@app.get('/api/get')
|
|
840
|
-
async def api_get(request_id: str) ->
|
|
1523
|
+
@app.get('/api/get', response_class=fastapi_responses.ORJSONResponse)
|
|
1524
|
+
async def api_get(request_id: str) -> payloads.RequestPayload:
|
|
841
1525
|
"""Gets a request with a given request ID prefix."""
|
|
1526
|
+
# Validate request_id prefix matches a single request.
|
|
1527
|
+
request_id = await get_expanded_request_id(request_id)
|
|
1528
|
+
|
|
842
1529
|
while True:
|
|
843
|
-
|
|
844
|
-
if
|
|
1530
|
+
req_status = await requests_lib.get_request_status_async(request_id)
|
|
1531
|
+
if req_status is None:
|
|
845
1532
|
print(f'No task with request ID {request_id}', flush=True)
|
|
846
1533
|
raise fastapi.HTTPException(
|
|
847
1534
|
status_code=404, detail=f'Request {request_id!r} not found')
|
|
848
|
-
if
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
return request_task.encode()
|
|
1535
|
+
if (req_status.status == requests_lib.RequestStatus.RUNNING and
|
|
1536
|
+
daemons.is_daemon_request_id(request_id)):
|
|
1537
|
+
# Daemon requests run forever, break without waiting for complete.
|
|
1538
|
+
break
|
|
1539
|
+
if req_status.status > requests_lib.RequestStatus.RUNNING:
|
|
1540
|
+
break
|
|
855
1541
|
# yield control to allow other coroutines to run, sleep shortly
|
|
856
1542
|
# to avoid storming the DB and CPU in the meantime
|
|
857
1543
|
await asyncio.sleep(0.1)
|
|
1544
|
+
request_task = await requests_lib.get_request_async(request_id)
|
|
1545
|
+
# TODO(aylei): refine this, /api/get will not be retried and this is
|
|
1546
|
+
# meaningless to retry. It is the original request that should be retried.
|
|
1547
|
+
if request_task.should_retry:
|
|
1548
|
+
raise fastapi.HTTPException(
|
|
1549
|
+
status_code=503, detail=f'Request {request_id!r} should be retried')
|
|
1550
|
+
request_error = request_task.get_error()
|
|
1551
|
+
if request_error is not None:
|
|
1552
|
+
raise fastapi.HTTPException(status_code=500,
|
|
1553
|
+
detail=request_task.encode().model_dump())
|
|
1554
|
+
return request_task.encode()
|
|
858
1555
|
|
|
859
1556
|
|
|
860
1557
|
@app.get('/api/stream')
|
|
@@ -888,13 +1585,18 @@ async def stream(
|
|
|
888
1585
|
clients, console for CLI/API clients), 'plain' (force plain text),
|
|
889
1586
|
'html' (force HTML), or 'console' (force console)
|
|
890
1587
|
"""
|
|
1588
|
+
# We need to save the user-supplied request ID for the response header.
|
|
1589
|
+
user_supplied_request_id = request_id
|
|
891
1590
|
if request_id is not None and log_path is not None:
|
|
892
1591
|
raise fastapi.HTTPException(
|
|
893
1592
|
status_code=400,
|
|
894
1593
|
detail='Only one of request_id and log_path can be provided')
|
|
895
1594
|
|
|
1595
|
+
if request_id is not None:
|
|
1596
|
+
request_id = await get_expanded_request_id(request_id)
|
|
1597
|
+
|
|
896
1598
|
if request_id is None and log_path is None:
|
|
897
|
-
request_id = requests_lib.
|
|
1599
|
+
request_id = await requests_lib.get_latest_request_id_async()
|
|
898
1600
|
if request_id is None:
|
|
899
1601
|
raise fastapi.HTTPException(status_code=404,
|
|
900
1602
|
detail='No request found')
|
|
@@ -921,19 +1623,40 @@ async def stream(
|
|
|
921
1623
|
'X-Accel-Buffering': 'no'
|
|
922
1624
|
})
|
|
923
1625
|
|
|
1626
|
+
polling_interval = stream_utils.DEFAULT_POLL_INTERVAL
|
|
924
1627
|
# Original plain text streaming logic
|
|
925
1628
|
if request_id is not None:
|
|
926
|
-
request_task = requests_lib.
|
|
1629
|
+
request_task = await requests_lib.get_request_async(
|
|
1630
|
+
request_id, fields=['request_id', 'schedule_type'])
|
|
927
1631
|
if request_task is None:
|
|
928
1632
|
print(f'No task with request ID {request_id}')
|
|
929
1633
|
raise fastapi.HTTPException(
|
|
930
1634
|
status_code=404, detail=f'Request {request_id!r} not found')
|
|
1635
|
+
# req.log_path is derived from request_id,
|
|
1636
|
+
# so it's ok to just grab the request_id in the above query.
|
|
931
1637
|
log_path_to_stream = request_task.log_path
|
|
1638
|
+
if not log_path_to_stream.exists():
|
|
1639
|
+
# The log file might be deleted by the request GC daemon but the
|
|
1640
|
+
# request task is still in the database.
|
|
1641
|
+
raise fastapi.HTTPException(
|
|
1642
|
+
status_code=404,
|
|
1643
|
+
detail=f'Log of request {request_id!r} has been deleted')
|
|
1644
|
+
if request_task.schedule_type == requests_lib.ScheduleType.LONG:
|
|
1645
|
+
polling_interval = stream_utils.LONG_REQUEST_POLL_INTERVAL
|
|
1646
|
+
del request_task
|
|
932
1647
|
else:
|
|
933
1648
|
assert log_path is not None, (request_id, log_path)
|
|
934
1649
|
if log_path == constants.API_SERVER_LOGS:
|
|
935
1650
|
resolved_log_path = pathlib.Path(
|
|
936
1651
|
constants.API_SERVER_LOGS).expanduser()
|
|
1652
|
+
if not resolved_log_path.exists():
|
|
1653
|
+
raise fastapi.HTTPException(
|
|
1654
|
+
status_code=404,
|
|
1655
|
+
detail='Server log file does not exist. The API server may '
|
|
1656
|
+
'have been started with `--foreground` - check the '
|
|
1657
|
+
'stdout of API server process, such as: '
|
|
1658
|
+
'`kubectl logs -n api-server-namespace '
|
|
1659
|
+
'api-server-pod-name`')
|
|
937
1660
|
else:
|
|
938
1661
|
# This should be a log path under ~/sky_logs.
|
|
939
1662
|
resolved_logs_directory = pathlib.Path(
|
|
@@ -954,18 +1677,26 @@ async def stream(
|
|
|
954
1677
|
detail=f'Log path {log_path!r} does not exist')
|
|
955
1678
|
|
|
956
1679
|
log_path_to_stream = resolved_log_path
|
|
1680
|
+
|
|
1681
|
+
headers = {
|
|
1682
|
+
'Cache-Control': 'no-cache, no-transform',
|
|
1683
|
+
'X-Accel-Buffering': 'no',
|
|
1684
|
+
'Transfer-Encoding': 'chunked'
|
|
1685
|
+
}
|
|
1686
|
+
if request_id is not None:
|
|
1687
|
+
headers[server_constants.STREAM_REQUEST_HEADER] = (
|
|
1688
|
+
user_supplied_request_id
|
|
1689
|
+
if user_supplied_request_id else request_id)
|
|
1690
|
+
|
|
957
1691
|
return fastapi.responses.StreamingResponse(
|
|
958
1692
|
content=stream_utils.log_streamer(request_id,
|
|
959
1693
|
log_path_to_stream,
|
|
960
1694
|
plain_logs=format == 'plain',
|
|
961
1695
|
tail=tail,
|
|
962
|
-
follow=follow
|
|
1696
|
+
follow=follow,
|
|
1697
|
+
polling_interval=polling_interval),
|
|
963
1698
|
media_type='text/plain',
|
|
964
|
-
headers=
|
|
965
|
-
'Cache-Control': 'no-cache, no-transform',
|
|
966
|
-
'X-Accel-Buffering': 'no',
|
|
967
|
-
'Transfer-Encoding': 'chunked'
|
|
968
|
-
},
|
|
1699
|
+
headers=headers,
|
|
969
1700
|
)
|
|
970
1701
|
|
|
971
1702
|
|
|
@@ -973,11 +1704,11 @@ async def stream(
|
|
|
973
1704
|
async def api_cancel(request: fastapi.Request,
|
|
974
1705
|
request_cancel_body: payloads.RequestCancelBody) -> None:
|
|
975
1706
|
"""Cancels requests."""
|
|
976
|
-
executor.
|
|
1707
|
+
await executor.schedule_request_async(
|
|
977
1708
|
request_id=request.state.request_id,
|
|
978
|
-
request_name=
|
|
1709
|
+
request_name=request_names.RequestName.API_CANCEL,
|
|
979
1710
|
request_body=request_cancel_body,
|
|
980
|
-
func=requests_lib.
|
|
1711
|
+
func=requests_lib.kill_requests_with_prefix,
|
|
981
1712
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
982
1713
|
)
|
|
983
1714
|
|
|
@@ -985,10 +1716,14 @@ async def api_cancel(request: fastapi.Request,
|
|
|
985
1716
|
@app.get('/api/status')
|
|
986
1717
|
async def api_status(
|
|
987
1718
|
request_ids: Optional[List[str]] = fastapi.Query(
|
|
988
|
-
None, description='Request
|
|
1719
|
+
None, description='Request ID prefixes to get status for.'),
|
|
989
1720
|
all_status: bool = fastapi.Query(
|
|
990
1721
|
False, description='Get finished requests as well.'),
|
|
991
|
-
|
|
1722
|
+
limit: Optional[int] = fastapi.Query(
|
|
1723
|
+
None, description='Number of requests to show.'),
|
|
1724
|
+
fields: Optional[List[str]] = fastapi.Query(
|
|
1725
|
+
None, description='Fields to get. If None, get all fields.'),
|
|
1726
|
+
) -> List[payloads.RequestPayload]:
|
|
992
1727
|
"""Gets the list of requests."""
|
|
993
1728
|
if request_ids is None:
|
|
994
1729
|
statuses = None
|
|
@@ -997,53 +1732,120 @@ async def api_status(
|
|
|
997
1732
|
requests_lib.RequestStatus.PENDING,
|
|
998
1733
|
requests_lib.RequestStatus.RUNNING,
|
|
999
1734
|
]
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1735
|
+
request_tasks = await requests_lib.get_request_tasks_async(
|
|
1736
|
+
req_filter=requests_lib.RequestTaskFilter(
|
|
1737
|
+
status=statuses,
|
|
1738
|
+
limit=limit,
|
|
1739
|
+
fields=fields,
|
|
1740
|
+
sort=True,
|
|
1741
|
+
))
|
|
1742
|
+
return requests_lib.encode_requests(request_tasks)
|
|
1004
1743
|
else:
|
|
1005
1744
|
encoded_request_tasks = []
|
|
1006
1745
|
for request_id in request_ids:
|
|
1007
|
-
|
|
1008
|
-
|
|
1746
|
+
request_tasks = await requests_lib.get_requests_async_with_prefix(
|
|
1747
|
+
request_id)
|
|
1748
|
+
if request_tasks is None:
|
|
1009
1749
|
continue
|
|
1010
|
-
|
|
1750
|
+
for request_task in request_tasks:
|
|
1751
|
+
encoded_request_tasks.append(request_task.readable_encode())
|
|
1011
1752
|
return encoded_request_tasks
|
|
1012
1753
|
|
|
1013
1754
|
|
|
1014
|
-
@app.get(
|
|
1015
|
-
|
|
1755
|
+
@app.get(
|
|
1756
|
+
'/api/health',
|
|
1757
|
+
# response_model_exclude_unset omits unset fields
|
|
1758
|
+
# in the response JSON.
|
|
1759
|
+
response_model_exclude_unset=True)
|
|
1760
|
+
async def health(request: fastapi.Request) -> responses.APIHealthResponse:
|
|
1016
1761
|
"""Checks the health of the API server.
|
|
1017
1762
|
|
|
1018
1763
|
Returns:
|
|
1019
|
-
|
|
1020
|
-
- status: str; The status of the API server.
|
|
1021
|
-
- api_version: str; The API version of the API server.
|
|
1022
|
-
- version: str; The version of SkyPilot used for API server.
|
|
1023
|
-
- version_on_disk: str; The version of the SkyPilot installation on
|
|
1024
|
-
disk, which can be used to warn about restarting the API server
|
|
1025
|
-
- commit: str; The commit hash of SkyPilot used for API server.
|
|
1764
|
+
responses.APIHealthResponse: The health response.
|
|
1026
1765
|
"""
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1766
|
+
user = request.state.auth_user
|
|
1767
|
+
server_status = common.ApiServerStatus.HEALTHY
|
|
1768
|
+
if getattr(request.state, 'anonymous_user', False):
|
|
1769
|
+
# API server authentication is enabled, but the request is not
|
|
1770
|
+
# authenticated. We still have to serve the request because the
|
|
1771
|
+
# /api/health endpoint has two different usages:
|
|
1772
|
+
# 1. For health check from `api start` and external orchestration
|
|
1773
|
+
# tools (k8s), which does not require authentication and user info.
|
|
1774
|
+
# 2. Return server info to client and hint client to login if required.
|
|
1775
|
+
# Separating these two usages into different APIs will break backward
|
|
1776
|
+
# compatibility for existing orchestration solutions (e.g. helm chart).
|
|
1777
|
+
# So we serve these two usages in a backward compatible manner below.
|
|
1778
|
+
client_version = versions.get_remote_api_version()
|
|
1779
|
+
# - For Client with API version >= 14, we return 200 response with
|
|
1780
|
+
# status=NEEDS_AUTH, new client will handle the login process.
|
|
1781
|
+
# - For health check from `sky api start`, the client code always uses
|
|
1782
|
+
# the same API version with the server, thus there is no compatibility
|
|
1783
|
+
# issue.
|
|
1784
|
+
server_status = common.ApiServerStatus.NEEDS_AUTH
|
|
1785
|
+
if client_version is None:
|
|
1786
|
+
# - For health check from orchestration tools (e.g. k8s), we also
|
|
1787
|
+
# return 200 with status=NEEDS_AUTH, which passes HTTP probe
|
|
1788
|
+
# check.
|
|
1789
|
+
# - There is no harm when a malicious client calls /api/health
|
|
1790
|
+
# without authentication since no sensitive information is
|
|
1791
|
+
# returned.
|
|
1792
|
+
return responses.APIHealthResponse(
|
|
1793
|
+
status=common.ApiServerStatus.HEALTHY,)
|
|
1794
|
+
# TODO(aylei): remove this after min_compatible_api_version >= 14.
|
|
1795
|
+
if client_version < 14:
|
|
1796
|
+
# For Client with API version < 14, the NEEDS_AUTH status is not
|
|
1797
|
+
# honored. Return 401 to trigger the login process.
|
|
1798
|
+
raise fastapi.HTTPException(status_code=401,
|
|
1799
|
+
detail='Authentication required')
|
|
1800
|
+
|
|
1801
|
+
logger.debug(f'Health endpoint: request.state.auth_user = {user}')
|
|
1802
|
+
return responses.APIHealthResponse(
|
|
1803
|
+
status=server_status,
|
|
1804
|
+
# Kept for backward compatibility, clients before 0.11.0 will read this
|
|
1805
|
+
# field to check compatibility and hint the user to upgrade the CLI.
|
|
1806
|
+
# TODO(aylei): remove this field after 0.13.0
|
|
1807
|
+
api_version=str(server_constants.API_VERSION),
|
|
1808
|
+
version=sky.__version__,
|
|
1809
|
+
version_on_disk=common.get_skypilot_version_on_disk(),
|
|
1810
|
+
commit=sky.__commit__,
|
|
1811
|
+
# Whether basic auth on api server is enabled
|
|
1812
|
+
basic_auth_enabled=os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH,
|
|
1813
|
+
'false').lower() == 'true',
|
|
1814
|
+
user=user if user is not None else None,
|
|
1815
|
+
# Whether service account token is enabled
|
|
1816
|
+
service_account_token_enabled=(os.environ.get(
|
|
1817
|
+
constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
|
|
1818
|
+
'false').lower() == 'true'),
|
|
1819
|
+
# Whether basic auth on ingress is enabled
|
|
1820
|
+
ingress_basic_auth_enabled=os.environ.get(
|
|
1821
|
+
constants.SKYPILOT_INGRESS_BASIC_AUTH_ENABLED,
|
|
1822
|
+
'false').lower() == 'true',
|
|
1823
|
+
)
|
|
1824
|
+
|
|
1825
|
+
|
|
1826
|
+
class KubernetesSSHMessageType(IntEnum):
|
|
1827
|
+
REGULAR_DATA = 0
|
|
1828
|
+
PINGPONG = 1
|
|
1829
|
+
LATENCY_MEASUREMENT = 2
|
|
1034
1830
|
|
|
1035
1831
|
|
|
1036
1832
|
@app.websocket('/kubernetes-pod-ssh-proxy')
|
|
1037
1833
|
async def kubernetes_pod_ssh_proxy(
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
) -> None:
|
|
1834
|
+
websocket: fastapi.WebSocket,
|
|
1835
|
+
cluster_name: str,
|
|
1836
|
+
client_version: Optional[int] = None) -> None:
|
|
1041
1837
|
"""Proxies SSH to the Kubernetes pod with websocket."""
|
|
1042
1838
|
await websocket.accept()
|
|
1043
|
-
cluster_name = cluster_name_body.cluster_name
|
|
1044
1839
|
logger.info(f'WebSocket connection accepted for cluster: {cluster_name}')
|
|
1045
1840
|
|
|
1046
|
-
|
|
1841
|
+
timestamps_supported = client_version is not None and client_version > 21
|
|
1842
|
+
logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
|
|
1843
|
+
client_version = {client_version}')
|
|
1844
|
+
|
|
1845
|
+
# Run core.status in another thread to avoid blocking the event loop.
|
|
1846
|
+
with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
|
1847
|
+
cluster_records = await context_utils.to_thread_with_executor(
|
|
1848
|
+
thread_pool_executor, core.status, cluster_name, all_users=True)
|
|
1047
1849
|
cluster_record = cluster_records[0]
|
|
1048
1850
|
if cluster_record['status'] != status_lib.ClusterStatus.UP:
|
|
1049
1851
|
raise fastapi.HTTPException(
|
|
@@ -1082,17 +1884,70 @@ async def kubernetes_pod_ssh_proxy(
|
|
|
1082
1884
|
return
|
|
1083
1885
|
|
|
1084
1886
|
logger.info(f'Starting port-forward to local port: {local_port}')
|
|
1887
|
+
conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
|
|
1888
|
+
pid=os.getpid())
|
|
1889
|
+
ssh_failed = False
|
|
1890
|
+
websocket_closed = False
|
|
1085
1891
|
try:
|
|
1892
|
+
conn_gauge.inc()
|
|
1086
1893
|
# Connect to the local port
|
|
1087
1894
|
reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
|
|
1088
1895
|
|
|
1089
1896
|
async def websocket_to_ssh():
|
|
1090
1897
|
try:
|
|
1091
1898
|
async for message in websocket.iter_bytes():
|
|
1899
|
+
if timestamps_supported:
|
|
1900
|
+
type_size = struct.calcsize('!B')
|
|
1901
|
+
message_type = struct.unpack('!B',
|
|
1902
|
+
message[:type_size])[0]
|
|
1903
|
+
if (message_type ==
|
|
1904
|
+
KubernetesSSHMessageType.REGULAR_DATA):
|
|
1905
|
+
# Regular data - strip type byte and forward to SSH
|
|
1906
|
+
message = message[type_size:]
|
|
1907
|
+
elif message_type == KubernetesSSHMessageType.PINGPONG:
|
|
1908
|
+
# PING message - respond with PONG (type 1)
|
|
1909
|
+
ping_id_size = struct.calcsize('!I')
|
|
1910
|
+
if len(message) != type_size + ping_id_size:
|
|
1911
|
+
raise ValueError('Invalid PING message '
|
|
1912
|
+
f'length: {len(message)}')
|
|
1913
|
+
# Return the same PING message, so that the client
|
|
1914
|
+
# can measure the latency.
|
|
1915
|
+
await websocket.send_bytes(message)
|
|
1916
|
+
continue
|
|
1917
|
+
elif (message_type ==
|
|
1918
|
+
KubernetesSSHMessageType.LATENCY_MEASUREMENT):
|
|
1919
|
+
# Latency measurement from client
|
|
1920
|
+
latency_size = struct.calcsize('!Q')
|
|
1921
|
+
if len(message) != type_size + latency_size:
|
|
1922
|
+
raise ValueError(
|
|
1923
|
+
'Invalid latency measurement '
|
|
1924
|
+
f'message length: {len(message)}')
|
|
1925
|
+
avg_latency_ms = struct.unpack(
|
|
1926
|
+
'!Q',
|
|
1927
|
+
message[type_size:type_size + latency_size])[0]
|
|
1928
|
+
latency_seconds = avg_latency_ms / 1000
|
|
1929
|
+
metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds) # pylint: disable=line-too-long
|
|
1930
|
+
continue
|
|
1931
|
+
else:
|
|
1932
|
+
# Unknown message type.
|
|
1933
|
+
raise ValueError(
|
|
1934
|
+
f'Unknown message type: {message_type}')
|
|
1092
1935
|
writer.write(message)
|
|
1093
|
-
|
|
1936
|
+
try:
|
|
1937
|
+
await writer.drain()
|
|
1938
|
+
except Exception as e: # pylint: disable=broad-except
|
|
1939
|
+
# Typically we will not reach here, if the ssh to pod
|
|
1940
|
+
# is disconnected, ssh_to_websocket will exit first.
|
|
1941
|
+
# But just in case.
|
|
1942
|
+
logger.error('Failed to write to pod through '
|
|
1943
|
+
f'port-forward connection: {e}')
|
|
1944
|
+
nonlocal ssh_failed
|
|
1945
|
+
ssh_failed = True
|
|
1946
|
+
break
|
|
1094
1947
|
except fastapi.WebSocketDisconnect:
|
|
1095
1948
|
pass
|
|
1949
|
+
nonlocal websocket_closed
|
|
1950
|
+
websocket_closed = True
|
|
1096
1951
|
writer.close()
|
|
1097
1952
|
|
|
1098
1953
|
async def ssh_to_websocket():
|
|
@@ -1100,87 +1955,262 @@ async def kubernetes_pod_ssh_proxy(
|
|
|
1100
1955
|
while True:
|
|
1101
1956
|
data = await reader.read(1024)
|
|
1102
1957
|
if not data:
|
|
1958
|
+
if not websocket_closed:
|
|
1959
|
+
logger.warning('SSH connection to pod is '
|
|
1960
|
+
'disconnected before websocket '
|
|
1961
|
+
'connection is closed')
|
|
1962
|
+
nonlocal ssh_failed
|
|
1963
|
+
ssh_failed = True
|
|
1103
1964
|
break
|
|
1965
|
+
if timestamps_supported:
|
|
1966
|
+
# Prepend message type byte (0 = regular data)
|
|
1967
|
+
message_type_bytes = struct.pack(
|
|
1968
|
+
'!B', KubernetesSSHMessageType.REGULAR_DATA.value)
|
|
1969
|
+
data = message_type_bytes + data
|
|
1104
1970
|
await websocket.send_bytes(data)
|
|
1105
1971
|
except Exception: # pylint: disable=broad-except
|
|
1106
1972
|
pass
|
|
1107
|
-
|
|
1973
|
+
try:
|
|
1974
|
+
await websocket.close()
|
|
1975
|
+
except Exception: # pylint: disable=broad-except
|
|
1976
|
+
# The websocket might have been closed by the client.
|
|
1977
|
+
pass
|
|
1108
1978
|
|
|
1109
1979
|
await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
|
|
1110
1980
|
finally:
|
|
1111
|
-
|
|
1981
|
+
conn_gauge.dec()
|
|
1982
|
+
reason = ''
|
|
1983
|
+
try:
|
|
1984
|
+
logger.info('Terminating kubectl port-forward process')
|
|
1985
|
+
proc.terminate()
|
|
1986
|
+
except ProcessLookupError:
|
|
1987
|
+
stdout = await proc.stdout.read()
|
|
1988
|
+
logger.error('kubectl port-forward was terminated before the '
|
|
1989
|
+
'ssh websocket connection was closed. Remaining '
|
|
1990
|
+
f'output: {str(stdout)}')
|
|
1991
|
+
reason = 'KubectlPortForwardExit'
|
|
1992
|
+
metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
|
|
1993
|
+
pid=os.getpid(), reason='KubectlPortForwardExit').inc()
|
|
1994
|
+
else:
|
|
1995
|
+
if ssh_failed:
|
|
1996
|
+
reason = 'SSHToPodDisconnected'
|
|
1997
|
+
else:
|
|
1998
|
+
reason = 'ClientClosed'
|
|
1999
|
+
metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
|
|
2000
|
+
pid=os.getpid(), reason=reason).inc()
|
|
2001
|
+
|
|
2002
|
+
|
|
2003
|
+
@app.get('/all_contexts')
|
|
2004
|
+
async def all_contexts(request: fastapi.Request) -> None:
|
|
2005
|
+
"""Gets all Kubernetes and SSH node pool contexts."""
|
|
2006
|
+
|
|
2007
|
+
await executor.schedule_request_async(
|
|
2008
|
+
request_id=request.state.request_id,
|
|
2009
|
+
request_name=request_names.RequestName.ALL_CONTEXTS,
|
|
2010
|
+
request_body=payloads.RequestBody(),
|
|
2011
|
+
func=core.get_all_contexts,
|
|
2012
|
+
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
2013
|
+
)
|
|
1112
2014
|
|
|
1113
2015
|
|
|
1114
2016
|
# === Internal APIs ===
|
|
1115
2017
|
@app.get('/api/completion/cluster_name')
|
|
1116
2018
|
async def complete_cluster_name(incomplete: str,) -> List[str]:
|
|
1117
|
-
return
|
|
2019
|
+
return await context_utils.to_thread(
|
|
2020
|
+
global_user_state.get_cluster_names_start_with, incomplete)
|
|
1118
2021
|
|
|
1119
2022
|
|
|
1120
2023
|
@app.get('/api/completion/storage_name')
|
|
1121
2024
|
async def complete_storage_name(incomplete: str,) -> List[str]:
|
|
1122
|
-
return
|
|
2025
|
+
return await context_utils.to_thread(
|
|
2026
|
+
global_user_state.get_storage_names_start_with, incomplete)
|
|
1123
2027
|
|
|
1124
2028
|
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
2029
|
+
@app.get('/api/completion/volume_name')
|
|
2030
|
+
async def complete_volume_name(incomplete: str,) -> List[str]:
|
|
2031
|
+
return await context_utils.to_thread(
|
|
2032
|
+
global_user_state.get_volume_names_start_with, incomplete)
|
|
1129
2033
|
|
|
1130
|
-
Handles the /dashboard prefix from Next.js configuration.
|
|
1131
|
-
"""
|
|
1132
|
-
# Check if the path starts with 'dashboard/' and remove it if it does
|
|
1133
|
-
if full_path.startswith('dashboard/'):
|
|
1134
|
-
full_path = full_path[len('dashboard/'):]
|
|
1135
2034
|
|
|
1136
|
-
|
|
2035
|
+
@app.get('/api/completion/api_request')
|
|
2036
|
+
async def complete_api_request(incomplete: str,) -> List[str]:
|
|
2037
|
+
return await requests_lib.get_api_request_ids_start_with(incomplete)
|
|
2038
|
+
|
|
2039
|
+
|
|
2040
|
+
@app.get('/dashboard/{full_path:path}')
|
|
2041
|
+
async def serve_dashboard(full_path: str):
|
|
2042
|
+
"""Serves the Next.js dashboard application.
|
|
2043
|
+
|
|
2044
|
+
Args:
|
|
2045
|
+
full_path: The path requested by the client.
|
|
2046
|
+
e.g. /clusters, /jobs
|
|
2047
|
+
|
|
2048
|
+
Returns:
|
|
2049
|
+
FileResponse for static files or index.html for client-side routing.
|
|
2050
|
+
|
|
2051
|
+
Raises:
|
|
2052
|
+
HTTPException: If the path is invalid or file not found.
|
|
2053
|
+
"""
|
|
2054
|
+
# Try to serve the static file directly e.g. /skypilot.svg,
|
|
2055
|
+
# /favicon.ico, and /_next/, etc.
|
|
1137
2056
|
file_path = os.path.join(server_constants.DASHBOARD_DIR, full_path)
|
|
1138
2057
|
if os.path.isfile(file_path):
|
|
1139
2058
|
return fastapi.responses.FileResponse(file_path)
|
|
1140
2059
|
|
|
1141
|
-
#
|
|
1142
|
-
#
|
|
1143
|
-
# client will be redirected to the index.html.
|
|
2060
|
+
# Serve index.html for client-side routing
|
|
2061
|
+
# e.g. /clusters, /jobs
|
|
1144
2062
|
index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
|
|
1145
2063
|
try:
|
|
1146
2064
|
with open(index_path, 'r', encoding='utf-8') as f:
|
|
1147
2065
|
content = f.read()
|
|
2066
|
+
|
|
1148
2067
|
return fastapi.responses.HTMLResponse(content=content)
|
|
1149
2068
|
except Exception as e:
|
|
1150
2069
|
logger.error(f'Error serving dashboard: {e}')
|
|
1151
2070
|
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
|
1152
2071
|
|
|
1153
2072
|
|
|
2073
|
+
# Redirect the root path to dashboard
|
|
2074
|
+
@app.get('/')
|
|
2075
|
+
async def root():
|
|
2076
|
+
return fastapi.responses.RedirectResponse(url='/dashboard/')
|
|
2077
|
+
|
|
2078
|
+
|
|
2079
|
+
def _init_or_restore_server_user_hash():
|
|
2080
|
+
"""Restores the server user hash from the global user state db.
|
|
2081
|
+
|
|
2082
|
+
The API server must have a stable user hash across restarts and potential
|
|
2083
|
+
multiple replicas. Thus we persist the user hash in db and restore it on
|
|
2084
|
+
startup. When upgrading from old version, the user hash will be read from
|
|
2085
|
+
the local file (if any) to keep the user hash consistent.
|
|
2086
|
+
"""
|
|
2087
|
+
|
|
2088
|
+
def apply_user_hash(user_hash: str) -> None:
|
|
2089
|
+
# For local API server, the user hash in db and local file should be
|
|
2090
|
+
# same so there is no harm to override here.
|
|
2091
|
+
common_utils.set_user_hash_locally(user_hash)
|
|
2092
|
+
# Refresh the server user hash for current process after restore or
|
|
2093
|
+
# initialize the user hash in db, child processes will get the correct
|
|
2094
|
+
# server id from the local cache file.
|
|
2095
|
+
common_lib.refresh_server_id()
|
|
2096
|
+
|
|
2097
|
+
user_hash = global_user_state.get_system_config(_SERVER_USER_HASH_KEY)
|
|
2098
|
+
if user_hash is not None:
|
|
2099
|
+
apply_user_hash(user_hash)
|
|
2100
|
+
return
|
|
2101
|
+
|
|
2102
|
+
# Initial deployment, generate a user hash and save it to the db.
|
|
2103
|
+
user_hash = common_utils.get_user_hash()
|
|
2104
|
+
global_user_state.set_system_config(_SERVER_USER_HASH_KEY, user_hash)
|
|
2105
|
+
apply_user_hash(user_hash)
|
|
2106
|
+
|
|
2107
|
+
|
|
1154
2108
|
if __name__ == '__main__':
|
|
1155
2109
|
import uvicorn
|
|
1156
2110
|
|
|
1157
2111
|
from sky.server import uvicorn as skyuvicorn
|
|
1158
2112
|
|
|
1159
|
-
|
|
2113
|
+
logger.info('Initializing SkyPilot API server')
|
|
2114
|
+
skyuvicorn.add_timestamp_prefix_for_server_logs()
|
|
1160
2115
|
|
|
1161
2116
|
parser = argparse.ArgumentParser()
|
|
1162
2117
|
parser.add_argument('--host', default='127.0.0.1')
|
|
1163
2118
|
parser.add_argument('--port', default=46580, type=int)
|
|
1164
2119
|
parser.add_argument('--deploy', action='store_true')
|
|
2120
|
+
# Serve metrics on a separate port to isolate it from the application APIs:
|
|
2121
|
+
# metrics port will not be exposed to the public network typically.
|
|
2122
|
+
parser.add_argument('--metrics-port', default=9090, type=int)
|
|
1165
2123
|
cmd_args = parser.parse_args()
|
|
2124
|
+
if cmd_args.port == cmd_args.metrics_port:
|
|
2125
|
+
logger.error('port and metrics-port cannot be the same, exiting.')
|
|
2126
|
+
raise ValueError('port and metrics-port cannot be the same')
|
|
2127
|
+
|
|
2128
|
+
# Fail fast if the port is not available to avoid corrupting the state
|
|
2129
|
+
# of a potentially running server instance.
|
|
2130
|
+
# We might reach here because the running server is currently not
|
|
2131
|
+
# responding, thus the healthz check fails and `sky api start` thinks
|
|
2132
|
+
# we should start a new server instance.
|
|
2133
|
+
if not common_utils.is_port_available(cmd_args.port):
|
|
2134
|
+
logger.error(f'Port {cmd_args.port} is not available, exiting.')
|
|
2135
|
+
raise RuntimeError(f'Port {cmd_args.port} is not available')
|
|
2136
|
+
|
|
2137
|
+
# Maybe touch the signal file on API server startup. Do it again here even
|
|
2138
|
+
# if we already touched it in the sky/server/common.py::_start_api_server.
|
|
2139
|
+
# This is because the sky/server/common.py::_start_api_server function call
|
|
2140
|
+
# is running outside the skypilot API server process tree. The process tree
|
|
2141
|
+
# starts within that function (see the `subprocess.Popen` call in
|
|
2142
|
+
# sky/server/common.py::_start_api_server). When pg is used, the
|
|
2143
|
+
# _start_api_server function will not load the config file from db, which
|
|
2144
|
+
# will ignore the consolidation mode config. Here, inside the process tree,
|
|
2145
|
+
# we already reload the config as a server (with env var _start_api_server),
|
|
2146
|
+
# so we will respect the consolidation mode config.
|
|
2147
|
+
# Refer to #7717 for more details.
|
|
2148
|
+
managed_job_utils.is_consolidation_mode(on_api_restart=True)
|
|
2149
|
+
|
|
1166
2150
|
# Show the privacy policy if it is not already shown. We place it here so
|
|
1167
2151
|
# that it is shown only when the API server is started.
|
|
1168
2152
|
usage_lib.maybe_show_privacy_policy()
|
|
1169
2153
|
|
|
1170
|
-
|
|
2154
|
+
# Initialize global user state db
|
|
2155
|
+
db_utils.set_max_connections(1)
|
|
2156
|
+
logger.info('Initializing database engine')
|
|
2157
|
+
global_user_state.initialize_and_get_db()
|
|
2158
|
+
logger.info('Database engine initialized')
|
|
2159
|
+
# Initialize request db
|
|
2160
|
+
requests_lib.reset_db_and_logs()
|
|
2161
|
+
# Restore the server user hash
|
|
2162
|
+
logger.info('Initializing server user hash')
|
|
2163
|
+
_init_or_restore_server_user_hash()
|
|
2164
|
+
|
|
2165
|
+
max_db_connections = global_user_state.get_max_db_connections()
|
|
2166
|
+
logger.info(f'Max db connections: {max_db_connections}')
|
|
2167
|
+
|
|
2168
|
+
# Reserve memory for jobs and serve/pool controller in consolidation mode.
|
|
2169
|
+
reserved_memory_mb = (
|
|
2170
|
+
controller_utils.compute_memory_reserved_for_controllers(
|
|
2171
|
+
reserve_for_controllers=os.environ.get(
|
|
2172
|
+
constants.OVERRIDE_CONSOLIDATION_MODE) is not None,
|
|
2173
|
+
# For jobs controller, we need to reserve for both jobs and
|
|
2174
|
+
# pool controller.
|
|
2175
|
+
reserve_extra_for_pool=not os.environ.get(
|
|
2176
|
+
constants.IS_SKYPILOT_SERVE_CONTROLLER)))
|
|
2177
|
+
|
|
2178
|
+
config = server_config.compute_server_config(
|
|
2179
|
+
cmd_args.deploy,
|
|
2180
|
+
max_db_connections,
|
|
2181
|
+
reserved_memory_mb=reserved_memory_mb)
|
|
2182
|
+
|
|
1171
2183
|
num_workers = config.num_server_workers
|
|
1172
2184
|
|
|
1173
|
-
|
|
2185
|
+
queue_server: Optional[multiprocessing.Process] = None
|
|
2186
|
+
workers: List[executor.RequestWorker] = []
|
|
2187
|
+
# Global background tasks that will be scheduled in a separate event loop.
|
|
2188
|
+
global_tasks: List[asyncio.Task] = []
|
|
1174
2189
|
try:
|
|
1175
|
-
|
|
2190
|
+
background = uvloop.new_event_loop()
|
|
2191
|
+
if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
|
|
2192
|
+
metrics_server = metrics.build_metrics_server(
|
|
2193
|
+
cmd_args.host, cmd_args.metrics_port)
|
|
2194
|
+
global_tasks.append(background.create_task(metrics_server.serve()))
|
|
2195
|
+
global_tasks.append(
|
|
2196
|
+
background.create_task(requests_lib.requests_gc_daemon()))
|
|
2197
|
+
global_tasks.append(
|
|
2198
|
+
background.create_task(
|
|
2199
|
+
global_user_state.cluster_event_retention_daemon()))
|
|
2200
|
+
threading.Thread(target=background.run_forever, daemon=True).start()
|
|
2201
|
+
|
|
2202
|
+
queue_server, workers = executor.start(config)
|
|
2203
|
+
|
|
1176
2204
|
logger.info(f'Starting SkyPilot API server, workers={num_workers}')
|
|
1177
2205
|
# We don't support reload for now, since it may cause leakage of request
|
|
1178
2206
|
# workers or interrupt running requests.
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
2207
|
+
uvicorn_config = uvicorn.Config('sky.server.server:app',
|
|
2208
|
+
host=cmd_args.host,
|
|
2209
|
+
port=cmd_args.port,
|
|
2210
|
+
workers=num_workers,
|
|
2211
|
+
ws_per_message_deflate=False)
|
|
2212
|
+
skyuvicorn.run(uvicorn_config,
|
|
2213
|
+
max_db_connections=config.num_db_connections_per_worker)
|
|
1184
2214
|
except Exception as exc: # pylint: disable=broad-except
|
|
1185
2215
|
logger.error(f'Failed to start SkyPilot API server: '
|
|
1186
2216
|
f'{common_utils.format_exception(exc, use_bracket=True)}')
|
|
@@ -1188,17 +2218,11 @@ if __name__ == '__main__':
|
|
|
1188
2218
|
finally:
|
|
1189
2219
|
logger.info('Shutting down SkyPilot API server...')
|
|
1190
2220
|
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
# Terminate processes in reverse order in case dependency, especially
|
|
1200
|
-
# queue server. Terminate queue server first does not affect the
|
|
1201
|
-
# correctness of cleanup but introduce redundant error messages.
|
|
1202
|
-
subprocess_utils.run_in_parallel(cleanup,
|
|
1203
|
-
list(reversed(sub_procs)),
|
|
1204
|
-
num_threads=len(sub_procs))
|
|
2221
|
+
for gt in global_tasks:
|
|
2222
|
+
gt.cancel()
|
|
2223
|
+
subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
|
|
2224
|
+
workers,
|
|
2225
|
+
num_threads=len(workers))
|
|
2226
|
+
if queue_server is not None:
|
|
2227
|
+
queue_server.kill()
|
|
2228
|
+
queue_server.join()
|