skypilot-nightly 1.0.0.dev20250626__py3-none-any.whl → 1.0.0.dev20250628__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/kubernetes.py +7 -0
  3. sky/adaptors/nebius.py +2 -2
  4. sky/admin_policy.py +27 -17
  5. sky/authentication.py +12 -5
  6. sky/backends/backend_utils.py +92 -26
  7. sky/check.py +5 -2
  8. sky/client/cli/command.py +38 -6
  9. sky/client/sdk.py +217 -167
  10. sky/client/service_account_auth.py +47 -0
  11. sky/clouds/aws.py +10 -4
  12. sky/clouds/azure.py +5 -2
  13. sky/clouds/cloud.py +5 -2
  14. sky/clouds/gcp.py +31 -18
  15. sky/clouds/kubernetes.py +54 -34
  16. sky/clouds/nebius.py +8 -2
  17. sky/clouds/ssh.py +5 -2
  18. sky/clouds/utils/aws_utils.py +10 -4
  19. sky/clouds/utils/gcp_utils.py +22 -7
  20. sky/clouds/utils/oci_utils.py +62 -14
  21. sky/dashboard/out/404.html +1 -1
  22. sky/dashboard/out/_next/static/{bs6UB9V4Jq10TIZ5x-kBK → ZYLkkWSYZjJhLVsObh20y}/_buildManifest.js +1 -1
  23. sky/dashboard/out/_next/static/chunks/43-f38a531f6692f281.js +1 -0
  24. sky/dashboard/out/_next/static/chunks/601-111d06d9ded11d00.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/{616-d6128fa9e7cae6e6.js → 616-50a620ac4a23deb4.js} +1 -1
  26. sky/dashboard/out/_next/static/chunks/691.fd9292250ab089af.js +21 -0
  27. sky/dashboard/out/_next/static/chunks/{785.dc2686c3c1235554.js → 785.3446c12ffdf3d188.js} +1 -1
  28. sky/dashboard/out/_next/static/chunks/871-e547295e7e21399c.js +6 -0
  29. sky/dashboard/out/_next/static/chunks/937.72796f7afe54075b.js +1 -0
  30. sky/dashboard/out/_next/static/chunks/938-0a770415b5ce4649.js +1 -0
  31. sky/dashboard/out/_next/static/chunks/982.d7bd80ed18cad4cc.js +1 -0
  32. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-21080826c6095f21.js +6 -0
  33. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-77d4816945b04793.js +6 -0
  34. sky/dashboard/out/_next/static/chunks/pages/{clusters-f119a5630a1efd61.js → clusters-65b2c90320b8afb8.js} +1 -1
  35. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-64bdc0b2d3a44709.js +16 -0
  36. sky/dashboard/out/_next/static/chunks/pages/{jobs-0a5695ff3075d94a.js → jobs-df7407b5e37d3750.js} +1 -1
  37. sky/dashboard/out/_next/static/chunks/pages/{users-4978cbb093e141e7.js → users-d7684eaa04c4f58f.js} +1 -1
  38. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-cb7e720b739de53a.js → [name]-04e1b3ad4207b1e9.js} +1 -1
  39. sky/dashboard/out/_next/static/chunks/pages/{workspaces-50e230828730cfb3.js → workspaces-c470366a6179f16e.js} +1 -1
  40. sky/dashboard/out/_next/static/chunks/{webpack-08fdb9e6070127fc.js → webpack-75a3310ef922a299.js} +1 -1
  41. sky/dashboard/out/_next/static/css/605ac87514049058.css +3 -0
  42. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  43. sky/dashboard/out/clusters/[cluster].html +1 -1
  44. sky/dashboard/out/clusters.html +1 -1
  45. sky/dashboard/out/config.html +1 -1
  46. sky/dashboard/out/index.html +1 -1
  47. sky/dashboard/out/infra/[context].html +1 -1
  48. sky/dashboard/out/infra.html +1 -1
  49. sky/dashboard/out/jobs/[job].html +1 -1
  50. sky/dashboard/out/jobs.html +1 -1
  51. sky/dashboard/out/users.html +1 -1
  52. sky/dashboard/out/volumes.html +1 -1
  53. sky/dashboard/out/workspace/new.html +1 -1
  54. sky/dashboard/out/workspaces/[name].html +1 -1
  55. sky/dashboard/out/workspaces.html +1 -1
  56. sky/data/storage.py +8 -3
  57. sky/global_user_state.py +257 -9
  58. sky/jobs/client/sdk.py +20 -25
  59. sky/models.py +16 -0
  60. sky/provision/kubernetes/config.py +1 -1
  61. sky/provision/kubernetes/instance.py +7 -4
  62. sky/provision/kubernetes/network.py +15 -9
  63. sky/provision/kubernetes/network_utils.py +42 -23
  64. sky/provision/kubernetes/utils.py +73 -35
  65. sky/provision/nebius/utils.py +10 -4
  66. sky/resources.py +10 -4
  67. sky/serve/client/sdk.py +28 -34
  68. sky/server/common.py +51 -3
  69. sky/server/constants.py +3 -0
  70. sky/server/requests/executor.py +4 -0
  71. sky/server/requests/payloads.py +33 -0
  72. sky/server/requests/requests.py +19 -0
  73. sky/server/rest.py +6 -15
  74. sky/server/server.py +121 -6
  75. sky/skylet/constants.py +6 -0
  76. sky/skypilot_config.py +32 -4
  77. sky/users/permission.py +29 -0
  78. sky/users/server.py +384 -5
  79. sky/users/token_service.py +196 -0
  80. sky/utils/common_utils.py +4 -5
  81. sky/utils/config_utils.py +41 -0
  82. sky/utils/controller_utils.py +5 -1
  83. sky/utils/resource_checker.py +153 -0
  84. sky/utils/resources_utils.py +12 -4
  85. sky/utils/schemas.py +87 -60
  86. sky/utils/subprocess_utils.py +2 -6
  87. sky/workspaces/core.py +9 -117
  88. {skypilot_nightly-1.0.0.dev20250626.dist-info → skypilot_nightly-1.0.0.dev20250628.dist-info}/METADATA +1 -1
  89. {skypilot_nightly-1.0.0.dev20250626.dist-info → skypilot_nightly-1.0.0.dev20250628.dist-info}/RECORD +95 -92
  90. sky/dashboard/out/_next/static/chunks/43-36177d00f6956ab2.js +0 -1
  91. sky/dashboard/out/_next/static/chunks/690.55f9eed3be903f56.js +0 -16
  92. sky/dashboard/out/_next/static/chunks/871-3db673be3ee3750b.js +0 -6
  93. sky/dashboard/out/_next/static/chunks/937.3759f538f11a0953.js +0 -1
  94. sky/dashboard/out/_next/static/chunks/938-068520cc11738deb.js +0 -1
  95. sky/dashboard/out/_next/static/chunks/973-81b2d057178adb76.js +0 -1
  96. sky/dashboard/out/_next/static/chunks/982.1b61658204416b0f.js +0 -1
  97. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-aff040d7bc5d0086.js +0 -6
  98. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-8040f2483897ed0c.js +0 -6
  99. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-e4b23128db0774cd.js +0 -16
  100. sky/dashboard/out/_next/static/css/52082cf558ec9705.css +0 -3
  101. /sky/dashboard/out/_next/static/{bs6UB9V4Jq10TIZ5x-kBK → ZYLkkWSYZjJhLVsObh20y}/_ssgManifest.js +0 -0
  102. /sky/dashboard/out/_next/static/chunks/pages/{_app-9a3ce3170d2edcec.js → _app-050a9e637b057b24.js} +0 -0
  103. {skypilot_nightly-1.0.0.dev20250626.dist-info → skypilot_nightly-1.0.0.dev20250628.dist-info}/WHEEL +0 -0
  104. {skypilot_nightly-1.0.0.dev20250626.dist-info → skypilot_nightly-1.0.0.dev20250628.dist-info}/entry_points.txt +0 -0
  105. {skypilot_nightly-1.0.0.dev20250626.dist-info → skypilot_nightly-1.0.0.dev20250628.dist-info}/licenses/LICENSE +0 -0
  106. {skypilot_nightly-1.0.0.dev20250626.dist-info → skypilot_nightly-1.0.0.dev20250628.dist-info}/top_level.txt +0 -0
sky/client/sdk.py CHANGED
@@ -147,9 +147,8 @@ def check(infra_list: Optional[Tuple[str, ...]],
147
147
  body = payloads.CheckBody(clouds=clouds,
148
148
  verbose=verbose,
149
149
  workspace=workspace)
150
- response = rest.post(f'{server_common.get_server_url()}/check',
151
- json=json.loads(body.model_dump_json()),
152
- cookies=server_common.get_api_cookie_jar())
150
+ response = server_common.make_authenticated_request(
151
+ 'POST', '/check', json=json.loads(body.model_dump_json()))
153
152
  return server_common.get_request_id(response)
154
153
 
155
154
 
@@ -173,9 +172,8 @@ def enabled_clouds(workspace: Optional[str] = None,
173
172
  """
174
173
  if workspace is None:
175
174
  workspace = skypilot_config.get_active_workspace()
176
- response = rest.get((f'{server_common.get_server_url()}/enabled_clouds?'
177
- f'workspace={workspace}&expand={expand}'),
178
- cookies=server_common.get_api_cookie_jar())
175
+ response = server_common.make_authenticated_request(
176
+ 'GET', f'/enabled_clouds?workspace={workspace}&expand={expand}')
179
177
  return server_common.get_request_id(response)
180
178
 
181
179
 
@@ -223,9 +221,8 @@ def list_accelerators(gpus_only: bool = True,
223
221
  require_price=require_price,
224
222
  case_sensitive=case_sensitive,
225
223
  )
226
- response = rest.post(f'{server_common.get_server_url()}/list_accelerators',
227
- json=json.loads(body.model_dump_json()),
228
- cookies=server_common.get_api_cookie_jar())
224
+ response = server_common.make_authenticated_request(
225
+ 'POST', '/list_accelerators', json=json.loads(body.model_dump_json()))
229
226
  return server_common.get_request_id(response)
230
227
 
231
228
 
@@ -263,10 +260,10 @@ def list_accelerator_counts(
263
260
  quantity_filter=quantity_filter,
264
261
  clouds=clouds,
265
262
  )
266
- response = rest.post(
267
- f'{server_common.get_server_url()}/list_accelerator_counts',
268
- json=json.loads(body.model_dump_json()),
269
- cookies=server_common.get_api_cookie_jar())
263
+ response = server_common.make_authenticated_request(
264
+ 'POST',
265
+ '/list_accelerator_counts',
266
+ json=json.loads(body.model_dump_json()))
270
267
  return server_common.get_request_id(response)
271
268
 
272
269
 
@@ -303,16 +300,14 @@ def optimize(
303
300
  body = payloads.OptimizeBody(dag=dag_str,
304
301
  minimize=minimize,
305
302
  request_options=admin_policy_request_options)
306
- response = rest.post(f'{server_common.get_server_url()}/optimize',
307
- json=json.loads(body.model_dump_json()),
308
- cookies=server_common.get_api_cookie_jar())
303
+ response = server_common.make_authenticated_request(
304
+ 'POST', '/optimize', json=json.loads(body.model_dump_json()))
309
305
  return server_common.get_request_id(response)
310
306
 
311
307
 
312
308
  def workspaces() -> server_common.RequestId:
313
309
  """Gets the workspaces."""
314
- response = rest.get(f'{server_common.get_server_url()}/workspaces',
315
- cookies=server_common.get_api_cookie_jar())
310
+ response = server_common.make_authenticated_request('GET', '/workspaces')
316
311
  return server_common.get_request_id(response)
317
312
 
318
313
 
@@ -346,9 +341,8 @@ def validate(
346
341
  dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
347
342
  body = payloads.ValidateBody(dag=dag_str,
348
343
  request_options=admin_policy_request_options)
349
- response = rest.post(f'{server_common.get_server_url()}/validate',
350
- json=json.loads(body.model_dump_json()),
351
- cookies=server_common.get_api_cookie_jar())
344
+ response = server_common.make_authenticated_request(
345
+ 'POST', '/validate', json=json.loads(body.model_dump_json()))
352
346
  if response.status_code == 400:
353
347
  with ux_utils.print_exception_no_traceback():
354
348
  raise exceptions.deserialize_exception(
@@ -632,12 +626,8 @@ def _launch(
632
626
  _is_launched_by_sky_serve_controller),
633
627
  disable_controller_check=_disable_controller_check,
634
628
  )
635
- response = rest.post(
636
- f'{server_common.get_server_url()}/launch',
637
- json=json.loads(body.model_dump_json()),
638
- timeout=5,
639
- cookies=server_common.get_api_cookie_jar(),
640
- )
629
+ response = server_common.make_authenticated_request(
630
+ 'POST', '/launch', json=json.loads(body.model_dump_json()), timeout=5)
641
631
  return server_common.get_request_id(response)
642
632
 
643
633
 
@@ -716,12 +706,8 @@ def exec( # pylint: disable=redefined-builtin
716
706
  backend=backend.NAME if backend else None,
717
707
  )
718
708
 
719
- response = rest.post(
720
- f'{server_common.get_server_url()}/exec',
721
- json=json.loads(body.model_dump_json()),
722
- timeout=5,
723
- cookies=server_common.get_api_cookie_jar(),
724
- )
709
+ response = server_common.make_authenticated_request(
710
+ 'POST', '/exec', json=json.loads(body.model_dump_json()), timeout=5)
725
711
  return server_common.get_request_id(response)
726
712
 
727
713
 
@@ -769,13 +755,13 @@ def tail_logs(cluster_name: str,
769
755
  follow=follow,
770
756
  tail=tail,
771
757
  )
772
- response = rest.post(
773
- f'{server_common.get_server_url()}/logs',
758
+ response = server_common.make_authenticated_request(
759
+ 'POST',
760
+ '/logs',
774
761
  json=json.loads(body.model_dump_json()),
775
762
  stream=True,
776
763
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
777
- None),
778
- cookies=server_common.get_api_cookie_jar())
764
+ None))
779
765
  request_id = server_common.get_request_id(response)
780
766
  # Log request is idempotent when tail is 0, thus can resume previous
781
767
  # streaming point on retry.
@@ -816,9 +802,8 @@ def download_logs(cluster_name: str,
816
802
  cluster_name=cluster_name,
817
803
  job_ids=job_ids,
818
804
  )
819
- response = rest.post(f'{server_common.get_server_url()}/download_logs',
820
- json=json.loads(body.model_dump_json()),
821
- cookies=server_common.get_api_cookie_jar())
805
+ response = server_common.make_authenticated_request(
806
+ 'POST', '/download_logs', json=json.loads(body.model_dump_json()))
822
807
  job_id_remote_path_dict = stream_and_get(
823
808
  server_common.get_request_id(response))
824
809
  remote2local_path_dict = client_common.download_logs_from_api_server(
@@ -896,12 +881,8 @@ def start(
896
881
  down=down,
897
882
  force=force,
898
883
  )
899
- response = rest.post(
900
- f'{server_common.get_server_url()}/start',
901
- json=json.loads(body.model_dump_json()),
902
- timeout=5,
903
- cookies=server_common.get_api_cookie_jar(),
904
- )
884
+ response = server_common.make_authenticated_request(
885
+ 'POST', '/start', json=json.loads(body.model_dump_json()), timeout=5)
905
886
  return server_common.get_request_id(response)
906
887
 
907
888
 
@@ -942,12 +923,8 @@ def down(cluster_name: str, purge: bool = False) -> server_common.RequestId:
942
923
  cluster_name=cluster_name,
943
924
  purge=purge,
944
925
  )
945
- response = rest.post(
946
- f'{server_common.get_server_url()}/down',
947
- json=json.loads(body.model_dump_json()),
948
- timeout=5,
949
- cookies=server_common.get_api_cookie_jar(),
950
- )
926
+ response = server_common.make_authenticated_request(
927
+ 'POST', '/down', json=json.loads(body.model_dump_json()), timeout=5)
951
928
  return server_common.get_request_id(response)
952
929
 
953
930
 
@@ -991,12 +968,8 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
991
968
  cluster_name=cluster_name,
992
969
  purge=purge,
993
970
  )
994
- response = rest.post(
995
- f'{server_common.get_server_url()}/stop',
996
- json=json.loads(body.model_dump_json()),
997
- timeout=5,
998
- cookies=server_common.get_api_cookie_jar(),
999
- )
971
+ response = server_common.make_authenticated_request(
972
+ 'POST', '/stop', json=json.loads(body.model_dump_json()), timeout=5)
1000
973
  return server_common.get_request_id(response)
1001
974
 
1002
975
 
@@ -1061,12 +1034,8 @@ def autostop(
1061
1034
  idle_minutes=idle_minutes,
1062
1035
  down=down,
1063
1036
  )
1064
- response = rest.post(
1065
- f'{server_common.get_server_url()}/autostop',
1066
- json=json.loads(body.model_dump_json()),
1067
- timeout=5,
1068
- cookies=server_common.get_api_cookie_jar(),
1069
- )
1037
+ response = server_common.make_authenticated_request(
1038
+ 'POST', '/autostop', json=json.loads(body.model_dump_json()), timeout=5)
1070
1039
  return server_common.get_request_id(response)
1071
1040
 
1072
1041
 
@@ -1124,9 +1093,8 @@ def queue(cluster_name: str,
1124
1093
  skip_finished=skip_finished,
1125
1094
  all_users=all_users,
1126
1095
  )
1127
- response = rest.post(f'{server_common.get_server_url()}/queue',
1128
- json=json.loads(body.model_dump_json()),
1129
- cookies=server_common.get_api_cookie_jar())
1096
+ response = server_common.make_authenticated_request(
1097
+ 'POST', '/queue', json=json.loads(body.model_dump_json()))
1130
1098
  return server_common.get_request_id(response)
1131
1099
 
1132
1100
 
@@ -1166,9 +1134,8 @@ def job_status(cluster_name: str,
1166
1134
  cluster_name=cluster_name,
1167
1135
  job_ids=job_ids,
1168
1136
  )
1169
- response = rest.post(f'{server_common.get_server_url()}/job_status',
1170
- json=json.loads(body.model_dump_json()),
1171
- cookies=server_common.get_api_cookie_jar())
1137
+ response = server_common.make_authenticated_request(
1138
+ 'POST', '/job_status', json=json.loads(body.model_dump_json()))
1172
1139
  return server_common.get_request_id(response)
1173
1140
 
1174
1141
 
@@ -1220,9 +1187,8 @@ def cancel(
1220
1187
  job_ids=job_ids,
1221
1188
  try_cancel_if_cluster_is_init=_try_cancel_if_cluster_is_init,
1222
1189
  )
1223
- response = rest.post(f'{server_common.get_server_url()}/cancel',
1224
- json=json.loads(body.model_dump_json()),
1225
- cookies=server_common.get_api_cookie_jar())
1190
+ response = server_common.make_authenticated_request(
1191
+ 'POST', '/cancel', json=json.loads(body.model_dump_json()))
1226
1192
  return server_common.get_request_id(response)
1227
1193
 
1228
1194
 
@@ -1316,9 +1282,8 @@ def status(
1316
1282
  refresh=refresh,
1317
1283
  all_users=all_users,
1318
1284
  )
1319
- response = rest.post(f'{server_common.get_server_url()}/status',
1320
- json=json.loads(body.model_dump_json()),
1321
- cookies=server_common.get_api_cookie_jar())
1285
+ response = server_common.make_authenticated_request(
1286
+ 'POST', '/status', json=json.loads(body.model_dump_json()))
1322
1287
  return server_common.get_request_id(response)
1323
1288
 
1324
1289
 
@@ -1351,9 +1316,8 @@ def endpoints(
1351
1316
  cluster=cluster,
1352
1317
  port=port,
1353
1318
  )
1354
- response = rest.post(f'{server_common.get_server_url()}/endpoints',
1355
- json=json.loads(body.model_dump_json()),
1356
- cookies=server_common.get_api_cookie_jar())
1319
+ response = server_common.make_authenticated_request(
1320
+ 'POST', '/endpoints', json=json.loads(body.model_dump_json()))
1357
1321
  return server_common.get_request_id(response)
1358
1322
 
1359
1323
 
@@ -1396,9 +1360,8 @@ def cost_report(days: Optional[int] = None) -> server_common.RequestId: # pylin
1396
1360
  }
1397
1361
  """
1398
1362
  body = payloads.CostReportBody(days=days)
1399
- response = rest.post(f'{server_common.get_server_url()}/cost_report',
1400
- json=json.loads(body.model_dump_json()),
1401
- cookies=server_common.get_api_cookie_jar())
1363
+ response = server_common.make_authenticated_request(
1364
+ 'POST', '/cost_report', json=json.loads(body.model_dump_json()))
1402
1365
  return server_common.get_request_id(response)
1403
1366
 
1404
1367
 
@@ -1427,8 +1390,7 @@ def storage_ls() -> server_common.RequestId:
1427
1390
  }
1428
1391
  ]
1429
1392
  """
1430
- response = rest.get(f'{server_common.get_server_url()}/storage/ls',
1431
- cookies=server_common.get_api_cookie_jar())
1393
+ response = server_common.make_authenticated_request('GET', '/storage/ls')
1432
1394
  return server_common.get_request_id(response)
1433
1395
 
1434
1396
 
@@ -1451,9 +1413,8 @@ def storage_delete(name: str) -> server_common.RequestId:
1451
1413
  ValueError: If the storage does not exist.
1452
1414
  """
1453
1415
  body = payloads.StorageBody(name=name)
1454
- response = rest.post(f'{server_common.get_server_url()}/storage/delete',
1455
- json=json.loads(body.model_dump_json()),
1456
- cookies=server_common.get_api_cookie_jar())
1416
+ response = server_common.make_authenticated_request(
1417
+ 'POST', '/storage/delete', json=json.loads(body.model_dump_json()))
1457
1418
  return server_common.get_request_id(response)
1458
1419
 
1459
1420
 
@@ -1490,9 +1451,8 @@ def local_up(gpus: bool,
1490
1451
  cleanup=cleanup,
1491
1452
  context_name=context_name,
1492
1453
  password=password)
1493
- response = rest.post(f'{server_common.get_server_url()}/local_up',
1494
- json=json.loads(body.model_dump_json()),
1495
- cookies=server_common.get_api_cookie_jar())
1454
+ response = server_common.make_authenticated_request(
1455
+ 'POST', '/local_up', json=json.loads(body.model_dump_json()))
1496
1456
  return server_common.get_request_id(response)
1497
1457
 
1498
1458
 
@@ -1508,8 +1468,7 @@ def local_down() -> server_common.RequestId:
1508
1468
  with ux_utils.print_exception_no_traceback():
1509
1469
  raise ValueError('sky local down is only supported when running '
1510
1470
  'SkyPilot locally.')
1511
- response = rest.post(f'{server_common.get_server_url()}/local_down',
1512
- cookies=server_common.get_api_cookie_jar())
1471
+ response = server_common.make_authenticated_request('POST', '/local_down')
1513
1472
  return server_common.get_request_id(response)
1514
1473
 
1515
1474
 
@@ -1538,9 +1497,9 @@ def _update_remote_ssh_node_pools(file: str,
1538
1497
  hosts_info = ssh_utils.prepare_hosts_info(
1539
1498
  name, pool_config, upload_ssh_key_func=_upload_ssh_key_and_wait)
1540
1499
  pools_config[name] = {'hosts': hosts_info}
1541
- rest.post(f'{server_common.get_server_url()}/ssh_node_pools',
1542
- json=pools_config,
1543
- cookies=server_common.get_api_cookie_jar())
1500
+ server_common.make_authenticated_request('POST',
1501
+ '/ssh_node_pools',
1502
+ json=pools_config)
1544
1503
 
1545
1504
 
1546
1505
  def _upload_ssh_key_and_wait(key_name: str, key_file_path: str) -> str:
@@ -1559,8 +1518,9 @@ def _upload_ssh_key_and_wait(key_name: str, key_file_path: str) -> str:
1559
1518
  raise ValueError(f'SSH key file not found: {key_file_path}')
1560
1519
 
1561
1520
  with open(os.path.expanduser(key_file_path), 'rb') as key_file:
1562
- response = rest.post(
1563
- f'{server_common.get_server_url()}/ssh_node_pools/keys',
1521
+ response = server_common.make_authenticated_request(
1522
+ 'POST',
1523
+ '/ssh_node_pools/keys',
1564
1524
  files={
1565
1525
  'key_file': (key_name, key_file, 'application/octet-stream')
1566
1526
  },
@@ -1593,15 +1553,14 @@ def ssh_up(infra: Optional[str] = None,
1593
1553
  body = payloads.SSHUpBody(infra=infra, cleanup=False)
1594
1554
  if infra is not None:
1595
1555
  # Call the specific pool deployment endpoint
1596
- response = rest.post(
1597
- f'{server_common.get_server_url()}/ssh_node_pools/{infra}/deploy',
1598
- cookies=server_common.get_api_cookie_jar())
1556
+ response = server_common.make_authenticated_request(
1557
+ 'POST', f'/ssh_node_pools/{infra}/deploy')
1599
1558
  else:
1600
1559
  # Call the general deployment endpoint
1601
- response = rest.post(
1602
- f'{server_common.get_server_url()}/ssh_node_pools/deploy',
1603
- json=json.loads(body.model_dump_json()),
1604
- cookies=server_common.get_api_cookie_jar())
1560
+ response = server_common.make_authenticated_request(
1561
+ 'POST',
1562
+ '/ssh_node_pools/deploy',
1563
+ json=json.loads(body.model_dump_json()))
1605
1564
  return server_common.get_request_id(response)
1606
1565
 
1607
1566
 
@@ -1622,15 +1581,14 @@ def ssh_down(infra: Optional[str] = None) -> server_common.RequestId:
1622
1581
  body = payloads.SSHUpBody(infra=infra, cleanup=True)
1623
1582
  if infra is not None:
1624
1583
  # Call the specific pool down endpoint
1625
- response = rest.post(
1626
- f'{server_common.get_server_url()}/ssh_node_pools/{infra}/down',
1627
- cookies=server_common.get_api_cookie_jar())
1584
+ response = server_common.make_authenticated_request(
1585
+ 'POST', f'/ssh_node_pools/{infra}/down')
1628
1586
  else:
1629
1587
  # Call the general down endpoint
1630
- response = rest.post(
1631
- f'{server_common.get_server_url()}/ssh_node_pools/down',
1632
- json=json.loads(body.model_dump_json()),
1633
- cookies=server_common.get_api_cookie_jar())
1588
+ response = server_common.make_authenticated_request(
1589
+ 'POST',
1590
+ '/ssh_node_pools/down',
1591
+ json=json.loads(body.model_dump_json()))
1634
1592
  return server_common.get_request_id(response)
1635
1593
 
1636
1594
 
@@ -1653,11 +1611,10 @@ def realtime_kubernetes_gpu_availability(
1653
1611
  quantity_filter=quantity_filter,
1654
1612
  is_ssh=is_ssh,
1655
1613
  )
1656
- response = rest.post(
1657
- f'{server_common.get_server_url()}/'
1658
- 'realtime_kubernetes_gpu_availability',
1659
- json=json.loads(body.model_dump_json()),
1660
- cookies=server_common.get_api_cookie_jar())
1614
+ response = server_common.make_authenticated_request(
1615
+ 'POST',
1616
+ '/realtime_kubernetes_gpu_availability',
1617
+ json=json.loads(body.model_dump_json()))
1661
1618
  return server_common.get_request_id(response)
1662
1619
 
1663
1620
 
@@ -1686,10 +1643,10 @@ def kubernetes_node_info(
1686
1643
  information.
1687
1644
  """
1688
1645
  body = payloads.KubernetesNodeInfoRequestBody(context=context)
1689
- response = rest.post(
1690
- f'{server_common.get_server_url()}/kubernetes_node_info',
1691
- json=json.loads(body.model_dump_json()),
1692
- cookies=server_common.get_api_cookie_jar())
1646
+ response = server_common.make_authenticated_request(
1647
+ 'POST',
1648
+ '/kubernetes_node_info',
1649
+ json=json.loads(body.model_dump_json()))
1693
1650
  return server_common.get_request_id(response)
1694
1651
 
1695
1652
 
@@ -1717,8 +1674,8 @@ def status_kubernetes() -> server_common.RequestId:
1717
1674
  dictionary job info, see jobs.queue_from_kubernetes_pod for details.
1718
1675
  - context: Kubernetes context used to fetch the cluster information.
1719
1676
  """
1720
- response = rest.get(f'{server_common.get_server_url()}/status_kubernetes',
1721
- cookies=server_common.get_api_cookie_jar())
1677
+ response = server_common.make_authenticated_request('GET',
1678
+ '/status_kubernetes')
1722
1679
  return server_common.get_request_id(response)
1723
1680
 
1724
1681
 
@@ -1744,11 +1701,12 @@ def get(request_id: str) -> Any:
1744
1701
  see ``Request Raises`` in the documentation of the specific requests
1745
1702
  above.
1746
1703
  """
1747
- response = rest.get_without_retry(
1748
- f'{server_common.get_server_url()}/api/get?request_id={request_id}',
1704
+ response = server_common.make_authenticated_request(
1705
+ 'GET',
1706
+ f'/api/get?request_id={request_id}',
1707
+ retry=False,
1749
1708
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1750
- None),
1751
- cookies=server_common.get_api_cookie_jar())
1709
+ None))
1752
1710
  request_task = None
1753
1711
  if response.status_code == 200:
1754
1712
  request_task = requests_lib.Request.decode(
@@ -1822,13 +1780,14 @@ def stream_and_get(
1822
1780
  'follow': follow,
1823
1781
  'format': 'console',
1824
1782
  }
1825
- response = rest.get_without_retry(
1826
- f'{server_common.get_server_url()}/api/stream',
1783
+ response = server_common.make_authenticated_request(
1784
+ 'GET',
1785
+ '/api/stream',
1827
1786
  params=params,
1787
+ retry=False,
1828
1788
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1829
1789
  None),
1830
- stream=True,
1831
- cookies=server_common.get_api_cookie_jar())
1790
+ stream=True)
1832
1791
  if response.status_code in [404, 400]:
1833
1792
  detail = response.json().get('detail')
1834
1793
  with ux_utils.print_exception_no_traceback():
@@ -1882,10 +1841,11 @@ def api_cancel(request_ids: Optional[Union[str, List[str]]] = None,
1882
1841
  echo(f'Cancelling {len(request_ids)} request{plural}: '
1883
1842
  f'{request_id_str}...')
1884
1843
 
1885
- response = rest.post(f'{server_common.get_server_url()}/api/cancel',
1886
- json=json.loads(body.model_dump_json()),
1887
- timeout=5,
1888
- cookies=server_common.get_api_cookie_jar())
1844
+ response = server_common.make_authenticated_request(
1845
+ 'POST',
1846
+ '/api/cancel',
1847
+ json=json.loads(body.model_dump_json()),
1848
+ timeout=5)
1889
1849
  return server_common.get_request_id(response)
1890
1850
 
1891
1851
 
@@ -1909,12 +1869,12 @@ def api_status(
1909
1869
  """
1910
1870
  body = payloads.RequestStatusBody(request_ids=request_ids,
1911
1871
  all_status=all_status)
1912
- response = rest.get(
1913
- f'{server_common.get_server_url()}/api/status',
1872
+ response = server_common.make_authenticated_request(
1873
+ 'GET',
1874
+ '/api/status',
1914
1875
  params=server_common.request_body_to_params(body),
1915
1876
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1916
- None),
1917
- cookies=server_common.get_api_cookie_jar())
1877
+ None))
1918
1878
  server_common.handle_request_error(response)
1919
1879
  return [
1920
1880
  requests_lib.RequestPayload(**request) for request in response.json()
@@ -1948,8 +1908,7 @@ def api_info() -> Dict[str, Any]:
1948
1908
  Note that user may be None if we are not using an auth proxy.
1949
1909
 
1950
1910
  """
1951
- response = rest.get(f'{server_common.get_server_url()}/api/health',
1952
- cookies=server_common.get_api_cookie_jar())
1911
+ response = server_common.make_authenticated_request('GET', '/api/health')
1953
1912
  response.raise_for_status()
1954
1913
  return response.json()
1955
1914
 
@@ -2074,9 +2033,53 @@ def api_server_logs(follow: bool = True, tail: Optional[int] = None) -> None:
2074
2033
  stream_and_get(log_path=constants.API_SERVER_LOGS, tail=tail)
2075
2034
 
2076
2035
 
2036
def _save_config_updates(endpoint: Optional[str] = None,
                         service_account_token: Optional[str] = None) -> None:
    """Persist the API server endpoint and/or service account token.

    Reads the user-level SkyPilot config file (treating a missing file as an
    empty config), updates its ``api_server`` section, writes it back and
    reloads the in-process config. The read-modify-write runs under a file
    lock placed next to the config file.

    Args:
        endpoint: New API server endpoint. When given, the whole
            ``api_server`` section is replaced, so stale settings such as a
            legacy service account token are dropped.
        service_account_token: Token to store under
            ``api_server.service_account_token``.
    """
    config_path = pathlib.Path(
        skypilot_config.get_user_config_path()).expanduser()
    # On a fresh machine the config directory may not exist yet; both the
    # lock file and the config file are created inside it.
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with filelock.FileLock(config_path.with_suffix('.lock')):
        if not config_path.exists():
            config_path.touch()
            config: Dict[str, Any] = {}
        else:
            config = dict(skypilot_config.get_user_config())

        # Update endpoint if provided.
        if endpoint is not None:
            # Always reset the api_server config to avoid a legacy
            # service account token.
            config['api_server'] = {'endpoint': endpoint}

        # Update service account token if provided.
        if service_account_token is not None:
            # `dict(config)` above is a shallow copy, so copy the nested
            # section rather than mutating the cached config in place. The
            # section may also be present but null in the YAML file.
            api_server = dict(config.get('api_server') or {})
            api_server['service_account_token'] = service_account_token
            config['api_server'] = api_server

        common_utils.dump_yaml(str(config_path), config)
        skypilot_config.reload_config()
2065
+
2066
+
2067
def _validate_endpoint(endpoint: Optional[str]) -> str:
    """Validate and normalize an API server endpoint URL.

    Prompts the user for the endpoint when none is given. Surrounding
    whitespace (e.g. from a pasted value) and trailing slashes are removed.

    Args:
        endpoint: Endpoint URL, or None to prompt interactively.

    Returns:
        The normalized endpoint, e.g. ``https://skypilot.mydomain.com``.

    Raises:
        click.BadParameter: If the endpoint is not an http(s) URL.
    """
    if endpoint is None:
        endpoint = click.prompt('Enter your SkyPilot API server endpoint')
    # Strip whitespace before validating so that a trailing newline or space
    # neither fails the scheme check nor survives the rstrip('/') below.
    endpoint = endpoint.strip()
    # Check endpoint is a valid URL.
    if not endpoint.startswith(('http://', 'https://')):
        raise click.BadParameter('Endpoint must be a valid URL.')
    return endpoint.rstrip('/')
2076
+
2077
+
2077
2078
  @usage_lib.entrypoint
2078
2079
  @annotations.client_api
2079
- def api_login(endpoint: Optional[str] = None, get_token: bool = False) -> None:
2080
+ def api_login(endpoint: Optional[str] = None,
2081
+ relogin: bool = False,
2082
+ service_account_token: Optional[str] = None) -> None:
2080
2083
  """Logs into a SkyPilot API server.
2081
2084
 
2082
2085
  This sets the endpoint globally, i.e., all SkyPilot CLI and SDK calls will
@@ -2088,25 +2091,80 @@ def api_login(endpoint: Optional[str] = None, get_token: bool = False) -> None:
2088
2091
  Args:
2089
2092
  endpoint: The endpoint of the SkyPilot API server, e.g.,
2090
2093
  http://1.2.3.4:46580 or https://skypilot.mydomain.com.
2091
- get_token: Whether to force getting a new token even if not needed.
2094
+ relogin: Whether to force relogin with OAuth2 when enabled.
2095
+ service_account_token: Service account token for authentication.
2092
2096
 
2093
2097
  Returns:
2094
2098
  None
2095
2099
  """
2100
+ # Validate and normalize endpoint
2101
+ endpoint = _validate_endpoint(endpoint)
2102
+
2103
+ def _show_logged_in_message(
2104
+ endpoint: str, dashboard_url: str, user: Optional[Dict[str, Any]],
2105
+ server_status: server_common.ApiServerStatus) -> None:
2106
+ """Show the logged in message."""
2107
+ if server_status != server_common.ApiServerStatus.HEALTHY:
2108
+ with ux_utils.print_exception_no_traceback():
2109
+ raise ValueError(f'Cannot log in API server at '
2110
+ f'{endpoint} (status: {server_status.value})')
2111
+
2112
+ identity_info = f'\n{ux_utils.INDENT_SYMBOL}{colorama.Fore.GREEN}User: '
2113
+ if user:
2114
+ user_name = user.get('name')
2115
+ user_id = user.get('id')
2116
+ if user_name and user_id:
2117
+ identity_info += f'{user_name} ({user_id})'
2118
+ elif user_id:
2119
+ identity_info += user_id
2120
+ else:
2121
+ identity_info = ''
2122
+ dashboard_msg = f'Dashboard: {dashboard_url}'
2123
+ click.secho(
2124
+ f'Logged into SkyPilot API server at: {endpoint}'
2125
+ f'{identity_info}'
2126
+ f'\n{ux_utils.INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
2127
+ f'{dashboard_msg}',
2128
+ fg='green')
2129
+
2130
+ # Handle service account token authentication
2131
+ if service_account_token:
2132
+ if not service_account_token.startswith('sky_'):
2133
+ raise ValueError('Invalid service account token format. '
2134
+ 'Token must start with "sky_"')
2135
+
2136
+ # Save both endpoint and token to config in a single operation
2137
+ _save_config_updates(endpoint=endpoint,
2138
+ service_account_token=service_account_token)
2139
+
2140
+ # Test the authentication by checking server health
2141
+ try:
2142
+ server_status, api_server_info = server_common.check_server_healthy(
2143
+ endpoint)
2144
+ dashboard_url = server_common.get_dashboard_url(endpoint)
2145
+ _show_logged_in_message(endpoint, dashboard_url,
2146
+ api_server_info.user, server_status)
2147
+
2148
+ return
2149
+ except exceptions.ApiServerConnectionError as e:
2150
+ with ux_utils.print_exception_no_traceback():
2151
+ raise RuntimeError(
2152
+ f'Failed to connect to API server at {endpoint}: {e}'
2153
+ ) from e
2154
+ except Exception as e: # pylint: disable=broad-except
2155
+ with ux_utils.print_exception_no_traceback():
2156
+ raise RuntimeError(
2157
+ f'{colorama.Fore.RED}Service account token authentication '
2158
+ f'failed:{colorama.Style.RESET_ALL} {e}') from None
2159
+
2160
+ # OAuth2/cookie-based authentication flow
2096
2161
  # TODO(zhwu): this SDK sets global endpoint, which may not be the best
2097
2162
  # design as a user may expect this is only effective for the current
2098
2163
  # session. We should consider using env var for specifying endpoint.
2099
- if endpoint is None:
2100
- endpoint = click.prompt('Enter your SkyPilot API server endpoint')
2101
- # Check endpoint is a valid URL
2102
- if (endpoint is not None and not endpoint.startswith('http://') and
2103
- not endpoint.startswith('https://')):
2104
- raise click.BadParameter('Endpoint must be a valid URL.')
2105
- endpoint = endpoint.rstrip('/')
2106
2164
 
2107
2165
  server_status, api_server_info = server_common.check_server_healthy(
2108
2166
  endpoint)
2109
- if server_status == server_common.ApiServerStatus.NEEDS_AUTH or get_token:
2167
+ if server_status == server_common.ApiServerStatus.NEEDS_AUTH or relogin:
2110
2168
  # We detected an auth proxy, so go through the auth proxy cookie flow.
2111
2169
  token: Optional[str] = None
2112
2170
  server: Optional[oauth_lib.HTTPServer] = None
@@ -2253,20 +2311,12 @@ def api_login(endpoint: Optional[str] = None, get_token: bool = False) -> None:
2253
2311
  f.write(user_hash)
2254
2312
 
2255
2313
  # Set the endpoint in the config file
2256
- config_path = pathlib.Path(
2257
- skypilot_config.get_user_config_path()).expanduser()
2258
- with filelock.FileLock(config_path.with_suffix('.lock')):
2259
- if not config_path.exists():
2260
- config_path.touch()
2261
- config = {'api_server': {'endpoint': endpoint}}
2262
- else:
2263
- config = skypilot_config.get_user_config()
2264
- config.set_nested(('api_server', 'endpoint'), endpoint)
2265
- common_utils.dump_yaml(str(config_path), dict(config))
2266
- dashboard_url = server_common.get_dashboard_url(endpoint)
2267
- dashboard_msg = f'Dashboard: {dashboard_url}'
2268
- click.secho(
2269
- f'Logged into SkyPilot API server at: {endpoint}'
2270
- f'\n{ux_utils.INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
2271
- f'{dashboard_msg}',
2272
- fg='green')
2314
+ _save_config_updates(endpoint=endpoint)
2315
+ dashboard_url = server_common.get_dashboard_url(endpoint)
2316
+
2317
+ # After successful authentication, check server health again to get user
2318
+ # identity
2319
+ server_status, final_api_server_info = server_common.check_server_healthy(
2320
+ endpoint)
2321
+ _show_logged_in_message(endpoint, dashboard_url, final_api_server_info.user,
2322
+ server_status)