skypilot-nightly 1.0.0.dev20250616__py3-none-any.whl → 1.0.0.dev20250618__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sky/__init__.py +2 -4
  2. sky/backends/backend_utils.py +7 -0
  3. sky/backends/cloud_vm_ray_backend.py +91 -96
  4. sky/cli.py +5 -6311
  5. sky/client/cli.py +66 -639
  6. sky/client/sdk.py +22 -2
  7. sky/clouds/kubernetes.py +8 -0
  8. sky/clouds/scp.py +7 -26
  9. sky/clouds/utils/scp_utils.py +177 -124
  10. sky/dashboard/out/404.html +1 -1
  11. sky/dashboard/out/_next/static/{OZxMW3bxAJmqgn5f4MdhO → LRpGymRCqq-feuFyoWz4m}/_buildManifest.js +1 -1
  12. sky/dashboard/out/_next/static/chunks/641.c8e452bc5070a630.js +1 -0
  13. sky/dashboard/out/_next/static/chunks/984.ae8c08791d274ca0.js +50 -0
  14. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-36bc0962129f72df.js +6 -0
  15. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-cf490d1fa38f3740.js +16 -0
  16. sky/dashboard/out/_next/static/chunks/pages/users-928edf039219e47b.js +1 -0
  17. sky/dashboard/out/_next/static/chunks/webpack-ebc2404fd6ce581c.js +1 -0
  18. sky/dashboard/out/_next/static/css/6c12ecc3bd2239b6.css +3 -0
  19. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  20. sky/dashboard/out/clusters/[cluster].html +1 -1
  21. sky/dashboard/out/clusters.html +1 -1
  22. sky/dashboard/out/config.html +1 -1
  23. sky/dashboard/out/index.html +1 -1
  24. sky/dashboard/out/infra/[context].html +1 -1
  25. sky/dashboard/out/infra.html +1 -1
  26. sky/dashboard/out/jobs/[job].html +1 -1
  27. sky/dashboard/out/jobs.html +1 -1
  28. sky/dashboard/out/users.html +1 -1
  29. sky/dashboard/out/workspace/new.html +1 -1
  30. sky/dashboard/out/workspaces/[name].html +1 -1
  31. sky/dashboard/out/workspaces.html +1 -1
  32. sky/global_user_state.py +50 -11
  33. sky/jobs/controller.py +98 -31
  34. sky/jobs/scheduler.py +37 -29
  35. sky/jobs/server/core.py +36 -3
  36. sky/jobs/state.py +69 -9
  37. sky/jobs/utils.py +11 -0
  38. sky/logs/__init__.py +17 -0
  39. sky/logs/agent.py +73 -0
  40. sky/logs/gcp.py +91 -0
  41. sky/models.py +1 -0
  42. sky/provision/__init__.py +1 -0
  43. sky/provision/instance_setup.py +35 -0
  44. sky/provision/provisioner.py +11 -0
  45. sky/provision/scp/__init__.py +15 -0
  46. sky/provision/scp/config.py +93 -0
  47. sky/provision/scp/instance.py +528 -0
  48. sky/resources.py +164 -29
  49. sky/server/common.py +21 -9
  50. sky/server/requests/payloads.py +19 -1
  51. sky/server/server.py +121 -29
  52. sky/setup_files/dependencies.py +11 -1
  53. sky/skylet/constants.py +48 -1
  54. sky/skylet/job_lib.py +83 -19
  55. sky/task.py +171 -21
  56. sky/templates/kubernetes-ray.yml.j2 +60 -4
  57. sky/templates/scp-ray.yml.j2 +3 -50
  58. sky/users/permission.py +47 -34
  59. sky/users/rbac.py +10 -1
  60. sky/users/server.py +274 -9
  61. sky/utils/command_runner.py +1 -1
  62. sky/utils/common_utils.py +16 -14
  63. sky/utils/context.py +1 -1
  64. sky/utils/controller_utils.py +12 -3
  65. sky/utils/dag_utils.py +17 -4
  66. sky/utils/kubernetes/deploy_remote_cluster.py +17 -8
  67. sky/utils/schemas.py +83 -5
  68. {skypilot_nightly-1.0.0.dev20250616.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/METADATA +9 -1
  69. {skypilot_nightly-1.0.0.dev20250616.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/RECORD +80 -79
  70. sky/benchmark/__init__.py +0 -0
  71. sky/benchmark/benchmark_state.py +0 -295
  72. sky/benchmark/benchmark_utils.py +0 -641
  73. sky/dashboard/out/_next/static/chunks/600.bd2ed8c076b720ec.js +0 -16
  74. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-59950b2f83b66e48.js +0 -6
  75. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-b3dbf38b51cb29be.js +0 -16
  76. sky/dashboard/out/_next/static/chunks/pages/users-c69ffcab9d6e5269.js +0 -1
  77. sky/dashboard/out/_next/static/chunks/webpack-1b69b196a4dbffef.js +0 -1
  78. sky/dashboard/out/_next/static/css/8e97adcaacc15293.css +0 -3
  79. sky/skylet/providers/scp/__init__.py +0 -2
  80. sky/skylet/providers/scp/config.py +0 -149
  81. sky/skylet/providers/scp/node_provider.py +0 -578
  82. /sky/dashboard/out/_next/static/{OZxMW3bxAJmqgn5f4MdhO → LRpGymRCqq-feuFyoWz4m}/_ssgManifest.js +0 -0
  83. /sky/dashboard/out/_next/static/chunks/{37-824c707421f6f003.js → 37-3a4d77ad62932eaf.js} +0 -0
  84. /sky/dashboard/out/_next/static/chunks/{843-ab9c4f609239155f.js → 843-b3040e493f6e7947.js} +0 -0
  85. /sky/dashboard/out/_next/static/chunks/{938-385d190b95815e11.js → 938-1493ac755eadeb35.js} +0 -0
  86. /sky/dashboard/out/_next/static/chunks/{973-c807fc34f09c7df3.js → 973-db3c97c2bfbceb65.js} +0 -0
  87. /sky/dashboard/out/_next/static/chunks/pages/{_app-32b2caae3445bf3b.js → _app-c416e87d5c2715cf.js} +0 -0
  88. /sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-c8c2191328532b7d.js → [name]-c4ff1ec05e2f3daf.js} +0 -0
  89. {skypilot_nightly-1.0.0.dev20250616.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/WHEEL +0 -0
  90. {skypilot_nightly-1.0.0.dev20250616.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/entry_points.txt +0 -0
  91. {skypilot_nightly-1.0.0.dev20250616.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/licenses/LICENSE +0 -0
  92. {skypilot_nightly-1.0.0.dev20250616.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/top_level.txt +0 -0
sky/resources.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Resources: compute requirements of Tasks."""
2
2
  import dataclasses
3
+ import math
3
4
  import textwrap
4
5
  import typing
5
6
  from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
@@ -60,7 +61,7 @@ class AutostopConfig:
60
61
 
61
62
  @classmethod
62
63
  def from_yaml_config(
63
- cls, config: Union[bool, int, Dict[str, Any], None]
64
+ cls, config: Union[bool, int, str, Dict[str, Any], None]
64
65
  ) -> Optional['AutostopConfig']:
65
66
  if isinstance(config, bool):
66
67
  if config:
@@ -71,6 +72,11 @@ class AutostopConfig:
71
72
  if isinstance(config, int):
72
73
  return cls(idle_minutes=config, down=False, enabled=True)
73
74
 
75
+ if isinstance(config, str):
76
+ return cls(idle_minutes=parse_time_minutes(config),
77
+ down=False,
78
+ enabled=True)
79
+
74
80
  if isinstance(config, dict):
75
81
  # If we have a dict, autostop is enabled. (Only way to disable is
76
82
  # with `false`, a bool.)
@@ -101,7 +107,7 @@ class Resources:
101
107
  """
102
108
  # If any fields changed, increment the version. For backward compatibility,
103
109
  # modify the __setstate__ method to handle the old version.
104
- _VERSION = 26
110
+ _VERSION = 27
105
111
 
106
112
  def __init__(
107
113
  self,
@@ -118,12 +124,13 @@ class Resources:
118
124
  region: Optional[str] = None,
119
125
  zone: Optional[str] = None,
120
126
  image_id: Union[Dict[Optional[str], str], str, None] = None,
121
- disk_size: Optional[int] = None,
127
+ disk_size: Optional[Union[str, int]] = None,
122
128
  disk_tier: Optional[Union[str, resources_utils.DiskTier]] = None,
123
129
  network_tier: Optional[Union[str, resources_utils.NetworkTier]] = None,
124
130
  ports: Optional[Union[int, str, List[str], Tuple[str]]] = None,
125
131
  labels: Optional[Dict[str, str]] = None,
126
- autostop: Union[bool, int, Dict[str, Any], None] = None,
132
+ autostop: Union[bool, int, str, Dict[str, Any], None] = None,
133
+ priority: Optional[int] = None,
127
134
  volumes: Optional[List[Dict[str, Any]]] = None,
128
135
  # Internal use only.
129
136
  # pylint: disable=invalid-name
@@ -217,6 +224,9 @@ class Resources:
217
224
  not supported and will be ignored.
218
225
  autostop: the autostop configuration to use. For launched resources,
219
226
  may or may not correspond to the actual current autostop config.
227
+ priority: the priority for this resource configuration. Must be an
228
+ integer from 0 to 1000, where higher values indicate higher priority.
229
+ If None, no priority is set.
220
230
  volumes: the volumes to mount on the instance.
221
231
  _docker_login_config: the docker configuration to use. This includes
222
232
  the docker username, password, and registry server. If None, skip
@@ -279,11 +289,7 @@ class Resources:
279
289
  self._job_recovery = job_recovery
280
290
 
281
291
  if disk_size is not None:
282
- if round(disk_size) != disk_size:
283
- with ux_utils.print_exception_no_traceback():
284
- raise ValueError(
285
- f'OS disk size must be an integer. Got: {disk_size}.')
286
- self._disk_size = int(disk_size)
292
+ self._disk_size = int(parse_memory_resource(disk_size, 'disk_size'))
287
293
  else:
288
294
  self._disk_size = _DEFAULT_DISK_SIZE_GB
289
295
 
@@ -357,10 +363,14 @@ class Resources:
357
363
  self._cluster_config_overrides = _cluster_config_overrides
358
364
  self._cached_repr: Optional[str] = None
359
365
 
366
+ # Initialize _priority before calling the setter
367
+ self._priority: Optional[int] = None
368
+
360
369
  self._set_cpus(cpus)
361
370
  self._set_memory(memory)
362
371
  self._set_accelerators(accelerators, accelerator_args)
363
372
  self._set_autostop_config(autostop)
373
+ self._set_priority(priority)
364
374
  self._set_volumes(volumes)
365
375
 
366
376
  def validate(self):
@@ -617,6 +627,14 @@ class Resources:
617
627
  """
618
628
  return self._autostop_config
619
629
 
630
+ @property
631
+ def priority(self) -> Optional[int]:
632
+ """The priority for this resource configuration.
633
+
634
+ Higher values indicate higher priority. Valid range is 0-1000.
635
+ """
636
+ return self._priority
637
+
620
638
  @property
621
639
  def is_image_managed(self) -> Optional[bool]:
622
640
  return self._is_image_managed
@@ -689,25 +707,27 @@ class Resources:
689
707
  self._memory = None
690
708
  return
691
709
 
692
- self._memory = str(memory)
693
- if isinstance(memory, str):
694
- if memory.endswith(('+', 'x')):
695
- # 'x' is used internally for make sure our resources used by
696
- # jobs controller (memory: 3x) to have enough memory based on
697
- # the vCPUs.
698
- num_memory_gb = memory[:-1]
699
- else:
700
- num_memory_gb = memory
701
-
702
- try:
703
- memory_gb = float(num_memory_gb)
704
- except ValueError:
705
- with ux_utils.print_exception_no_traceback():
706
- raise ValueError(
707
- f'The "memory" field should be either a number or '
708
- f'a string "<number>+". Found: {memory!r}') from None
710
+ memory = parse_memory_resource(str(memory),
711
+ 'memory',
712
+ ret_type=float,
713
+ allow_plus=True,
714
+ allow_x=True)
715
+ self._memory = memory
716
+ if memory.endswith(('+', 'x')):
717
+ # 'x' is used internally for make sure our resources used by
718
+ # jobs controller (memory: 3x) to have enough memory based on
719
+ # the vCPUs.
720
+ num_memory_gb = memory[:-1]
709
721
  else:
710
- memory_gb = float(memory)
722
+ num_memory_gb = memory
723
+
724
+ try:
725
+ memory_gb = float(num_memory_gb)
726
+ except ValueError:
727
+ with ux_utils.print_exception_no_traceback():
728
+ raise ValueError(
729
+ f'The "memory" field should be either a number or '
730
+ f'a string "<number>+". Found: {memory!r}') from None
711
731
 
712
732
  if memory_gb <= 0:
713
733
  with ux_utils.print_exception_no_traceback():
@@ -796,10 +816,24 @@ class Resources:
796
816
 
797
817
  def _set_autostop_config(
798
818
  self,
799
- autostop: Union[bool, int, Dict[str, Any], None],
819
+ autostop: Union[bool, int, str, Dict[str, Any], None],
800
820
  ) -> None:
801
821
  self._autostop_config = AutostopConfig.from_yaml_config(autostop)
802
822
 
823
+ def _set_priority(self, priority: Optional[int]) -> None:
824
+ """Sets the priority for this resource configuration.
825
+
826
+ Args:
827
+ priority: Priority value from 0 to 1000, where higher values
828
+ indicate higher priority. If None, no priority is set.
829
+ """
830
+ if priority is not None:
831
+ if not 0 <= priority <= 1000:
832
+ with ux_utils.print_exception_no_traceback():
833
+ raise ValueError(f'Priority must be between 0 and 1000. '
834
+ f'Found: {priority}')
835
+ self._priority = priority
836
+
803
837
  def _set_volumes(
804
838
  self,
805
839
  volumes: Optional[List[Dict[str, Any]]],
@@ -852,6 +886,7 @@ class Resources:
852
886
  else:
853
887
  volume['attach_mode'] = read_write_mode
854
888
  if volume['storage_type'] == network_type:
889
+ # TODO(luca): add units to this disk_size as well
855
890
  if ('disk_size' in volume and
856
891
  round(volume['disk_size']) != volume['disk_size']):
857
892
  with ux_utils.print_exception_no_traceback():
@@ -1716,6 +1751,7 @@ class Resources:
1716
1751
  ports=override.pop('ports', self.ports),
1717
1752
  labels=override.pop('labels', self.labels),
1718
1753
  autostop=override.pop('autostop', current_autostop_config),
1754
+ priority=override.pop('priority', self.priority),
1719
1755
  volumes=override.pop('volumes', self.volumes),
1720
1756
  infra=override.pop('infra', None),
1721
1757
  _docker_login_config=override.pop('_docker_login_config',
@@ -1936,6 +1972,7 @@ class Resources:
1936
1972
  resources_fields['ports'] = config.pop('ports', None)
1937
1973
  resources_fields['labels'] = config.pop('labels', None)
1938
1974
  resources_fields['autostop'] = config.pop('autostop', None)
1975
+ resources_fields['priority'] = config.pop('priority', None)
1939
1976
  resources_fields['volumes'] = config.pop('volumes', None)
1940
1977
  resources_fields['_docker_login_config'] = config.pop(
1941
1978
  '_docker_login_config', None)
@@ -1955,7 +1992,9 @@ class Resources:
1955
1992
  resources_fields['accelerator_args'] = dict(
1956
1993
  resources_fields['accelerator_args'])
1957
1994
  if resources_fields['disk_size'] is not None:
1958
- resources_fields['disk_size'] = int(resources_fields['disk_size'])
1995
+ # although it will end up being an int, we don't know at this point
1996
+ # if it has units or not, so we store it as a string
1997
+ resources_fields['disk_size'] = str(resources_fields['disk_size'])
1959
1998
 
1960
1999
  assert not config, f'Invalid resource args: {config.keys()}'
1961
2000
  return Resources(**resources_fields)
@@ -2006,6 +2045,7 @@ class Resources:
2006
2045
  config['volumes'] = volumes
2007
2046
  if self._autostop_config is not None:
2008
2047
  config['autostop'] = self._autostop_config.to_yaml_config()
2048
+ add_if_not_none('priority', self.priority)
2009
2049
  if self._docker_login_config is not None:
2010
2050
  config['_docker_login_config'] = dataclasses.asdict(
2011
2051
  self._docker_login_config)
@@ -2174,6 +2214,9 @@ class Resources:
2174
2214
  if version < 26:
2175
2215
  self._network_tier = state.get('_network_tier', None)
2176
2216
 
2217
+ if version < 27:
2218
+ self._priority = None
2219
+
2177
2220
  self.__dict__.update(state)
2178
2221
 
2179
2222
 
@@ -2219,3 +2262,95 @@ def _maybe_add_docker_prefix_to_image_id(
2219
2262
  for k, v in image_id_dict.items():
2220
2263
  if not v.startswith('docker:'):
2221
2264
  image_id_dict[k] = f'docker:{v}'
2265
+
2266
+
2267
+ def parse_time_minutes(time: str) -> int:
2268
+ """Convert a time string to minutes.
2269
+
2270
+ Args:
2271
+ time: Time string with optional unit suffix (e.g., '30m', '2h', '1d')
2272
+
2273
+ Returns:
2274
+ Time in minutes as an integer
2275
+ """
2276
+ time_str = str(time)
2277
+
2278
+ if time_str.isdecimal():
2279
+ # We assume it is already in minutes to maintain backwards
2280
+ # compatibility
2281
+ return int(time_str)
2282
+
2283
+ time_str = time_str.lower()
2284
+ for unit, multiplier in constants.TIME_UNITS.items():
2285
+ if time_str.endswith(unit):
2286
+ try:
2287
+ value = int(time_str[:-len(unit)])
2288
+ return math.ceil(value * multiplier)
2289
+ except ValueError:
2290
+ continue
2291
+
2292
+ raise ValueError(f'Invalid time format: {time}')
2293
+
2294
+
2295
+ def parse_memory_resource(resource_qty_str: Union[str, int, float],
2296
+ field_name: str,
2297
+ ret_type: type = int,
2298
+ unit: str = 'g',
2299
+ allow_plus: bool = False,
2300
+ allow_x: bool = False,
2301
+ allow_rounding: bool = False) -> str:
2302
+ """Returns memory size in chosen units given a resource quantity string.
2303
+
2304
+ Args:
2305
+ resource_qty_str: Resource quantity string
2306
+ unit: Unit to convert to
2307
+ allow_plus: Whether to allow '+' prefix
2308
+ allow_x: Whether to allow 'x' suffix
2309
+ """
2310
+ assert unit in constants.MEMORY_SIZE_UNITS, f'Invalid unit: {unit}'
2311
+
2312
+ error_msg = f'"{field_name}" field should be a <int><b|k|m|g|t|p><+?>,'\
2313
+ f' got {resource_qty_str}'
2314
+
2315
+ resource_str = str(resource_qty_str)
2316
+
2317
+ # Handle plus and x suffixes, x is only used internally for jobs controller
2318
+ plus = ''
2319
+ if resource_str.endswith('+'):
2320
+ if allow_plus:
2321
+ resource_str = resource_str[:-1]
2322
+ plus = '+'
2323
+ else:
2324
+ raise ValueError(error_msg)
2325
+
2326
+ x = ''
2327
+ if resource_str.endswith('x'):
2328
+ if allow_x:
2329
+ resource_str = resource_str[:-1]
2330
+ x = 'x'
2331
+ else:
2332
+ raise ValueError(error_msg)
2333
+
2334
+ try:
2335
+ # We assume it is already in the wanted units to maintain backwards
2336
+ # compatibility
2337
+ ret_type(resource_str)
2338
+ return f'{resource_str}{plus}{x}'
2339
+ except ValueError:
2340
+ pass
2341
+
2342
+ resource_str = resource_str.lower()
2343
+ for mem_unit, multiplier in constants.MEMORY_SIZE_UNITS.items():
2344
+ if resource_str.endswith(mem_unit):
2345
+ try:
2346
+ value = ret_type(resource_str[:-len(mem_unit)])
2347
+ converted = (value * multiplier /
2348
+ constants.MEMORY_SIZE_UNITS[unit])
2349
+ if not allow_rounding and ret_type(converted) != converted:
2350
+ raise ValueError(error_msg)
2351
+ converted = ret_type(converted)
2352
+ return f'{converted}{plus}{x}'
2353
+ except ValueError:
2354
+ continue
2355
+
2356
+ raise ValueError(error_msg)
sky/server/common.py CHANGED
@@ -13,7 +13,7 @@ import subprocess
13
13
  import sys
14
14
  import time
15
15
  import typing
16
- from typing import Any, Dict, Literal, Optional
16
+ from typing import Any, Dict, Literal, Optional, Tuple
17
17
  from urllib import parse
18
18
  import uuid
19
19
 
@@ -128,6 +128,8 @@ class ApiServerInfo:
128
128
  version: Optional[str] = None
129
129
  version_on_disk: Optional[str] = None
130
130
  commit: Optional[str] = None
131
+ user: Optional[Dict[str, Any]] = None
132
+ basic_auth_enabled: bool = False
131
133
 
132
134
 
133
135
  def get_api_cookie_jar_path() -> pathlib.Path:
@@ -261,11 +263,15 @@ def get_api_server_status(endpoint: Optional[str] = None) -> ApiServerInfo:
261
263
  version = result.get('version')
262
264
  version_on_disk = result.get('version_on_disk')
263
265
  commit = result.get('commit')
266
+ user = result.get('user')
267
+ basic_auth_enabled = result.get('basic_auth_enabled')
264
268
  server_info = ApiServerInfo(status=ApiServerStatus.HEALTHY,
265
269
  api_version=api_version,
266
270
  version=version,
267
271
  version_on_disk=version_on_disk,
268
- commit=commit)
272
+ commit=commit,
273
+ user=user,
274
+ basic_auth_enabled=basic_auth_enabled)
269
275
  if api_version is None or version is None or commit is None:
270
276
  logger.warning(f'API server response missing '
271
277
  f'version info. {server_url} may '
@@ -320,7 +326,8 @@ def get_request_id(response: 'requests.Response') -> RequestId:
320
326
 
321
327
  def _start_api_server(deploy: bool = False,
322
328
  host: str = '127.0.0.1',
323
- foreground: bool = False):
329
+ foreground: bool = False,
330
+ enable_basic_auth: bool = False):
324
331
  """Starts a SkyPilot API server locally."""
325
332
  server_url = get_server_url(host)
326
333
  assert server_url in AVAILABLE_LOCAL_API_SERVER_URLS, (
@@ -354,6 +361,8 @@ def _start_api_server(deploy: bool = False,
354
361
  if foreground:
355
362
  # Replaces the current process with the API server
356
363
  os.environ[constants.ENV_VAR_IS_SKYPILOT_SERVER] = 'true'
364
+ if enable_basic_auth:
365
+ os.environ[constants.ENV_VAR_ENABLE_BASIC_AUTH] = 'true'
357
366
  os.execvp(args[0], args)
358
367
 
359
368
  log_path = os.path.expanduser(constants.API_SERVER_LOGS)
@@ -365,6 +374,8 @@ def _start_api_server(deploy: bool = False,
365
374
  # the API server.
366
375
  server_env = os.environ.copy()
367
376
  server_env[constants.ENV_VAR_IS_SKYPILOT_SERVER] = 'true'
377
+ if enable_basic_auth:
378
+ server_env[constants.ENV_VAR_ENABLE_BASIC_AUTH] = 'true'
368
379
  with open(log_path, 'w', encoding='utf-8') as log_file:
369
380
  # Because the log file is opened using a with statement, it may seem
370
381
  # that the file will be closed when the with statement is exited
@@ -428,10 +439,10 @@ def _start_api_server(deploy: bool = False,
428
439
 
429
440
  def check_server_healthy(
430
441
  endpoint: Optional[str] = None
431
- ) -> Literal[
442
+ ) -> Tuple[Literal[
432
443
  # Use an incomplete list of Literals here to enforce raising for other
433
444
  # enum values.
434
- ApiServerStatus.HEALTHY, ApiServerStatus.NEEDS_AUTH]:
445
+ ApiServerStatus.HEALTHY, ApiServerStatus.NEEDS_AUTH], ApiServerInfo]:
435
446
  """Check if the API server is healthy.
436
447
 
437
448
  Args:
@@ -508,7 +519,7 @@ def check_server_healthy(
508
519
 
509
520
  hinted_for_server_install_version_mismatch = True
510
521
 
511
- return api_server_status
522
+ return api_server_status, api_server_info
512
523
 
513
524
 
514
525
  def _get_version_info_hint(server_info: ApiServerInfo) -> str:
@@ -559,10 +570,11 @@ def get_skypilot_version_on_disk() -> str:
559
570
 
560
571
  def check_server_healthy_or_start_fn(deploy: bool = False,
561
572
  host: str = '127.0.0.1',
562
- foreground: bool = False):
573
+ foreground: bool = False,
574
+ enable_basic_auth: bool = False):
563
575
  api_server_status = None
564
576
  try:
565
- api_server_status = check_server_healthy()
577
+ api_server_status, _ = check_server_healthy()
566
578
  if api_server_status == ApiServerStatus.NEEDS_AUTH:
567
579
  endpoint = get_server_url()
568
580
  with ux_utils.print_exception_no_traceback():
@@ -580,7 +592,7 @@ def check_server_healthy_or_start_fn(deploy: bool = False,
580
592
  # have started the server while we were waiting for the lock.
581
593
  api_server_info = get_api_server_status(endpoint)
582
594
  if api_server_info.status == ApiServerStatus.UNHEALTHY:
583
- _start_api_server(deploy, host, foreground)
595
+ _start_api_server(deploy, host, foreground, enable_basic_auth)
584
596
 
585
597
 
586
598
  def check_server_healthy_or_start(func):
sky/server/requests/payloads.py CHANGED
@@ -336,10 +336,28 @@ class ClusterJobsDownloadLogsBody(RequestBody):
336
336
  local_dir: str = constants.SKY_LOGS_DIRECTORY
337
337
 
338
338
 
339
+ class UserCreateBody(RequestBody):
340
+ """The request body for the user create endpoint."""
341
+ username: str
342
+ password: str
343
+ role: Optional[str] = None
344
+
345
+
346
+ class UserDeleteBody(RequestBody):
347
+ """The request body for the user delete endpoint."""
348
+ user_id: str
349
+
350
+
339
351
  class UserUpdateBody(RequestBody):
340
352
  """The request body for the user update endpoint."""
341
353
  user_id: str
342
- role: str
354
+ role: Optional[str] = None
355
+ password: Optional[str] = None
356
+
357
+
358
+ class UserImportBody(RequestBody):
359
+ """The request body for the user import endpoint."""
360
+ csv_content: str
343
361
 
344
362
 
345
363
  class DownloadBody(RequestBody):
sky/server/server.py CHANGED
@@ -23,6 +23,7 @@ import zipfile
23
23
  import aiofiles
24
24
  import fastapi
25
25
  from fastapi.middleware import cors
26
+ from passlib.hash import apr_md5_crypt
26
27
  import starlette.middleware.base
27
28
 
28
29
  import sky
@@ -102,6 +103,74 @@ logger = sky_logging.init_logger(__name__)
102
103
  # response will block other requests from being processed.
103
104
 
104
105
 
106
+ def _basic_auth_401_response(content: str):
107
+ """Return a 401 response with basic auth realm."""
108
+ return fastapi.responses.JSONResponse(
109
+ status_code=401,
110
+ headers={'WWW-Authenticate': 'Basic realm=\"SkyPilot\"'},
111
+ content=content)
112
+
113
+
114
+ # TODO(hailong): Remove this function and use request.state.auth_user instead.
115
+ async def _override_user_info_in_request_body(request: fastapi.Request,
116
+ auth_user: Optional[models.User]):
117
+ body = await request.body()
118
+ if auth_user and body:
119
+ try:
120
+ original_json = await request.json()
121
+ except json.JSONDecodeError as e:
122
+ logger.error(f'Error parsing request JSON: {e}')
123
+ else:
124
+ logger.debug(f'Overriding user for {request.state.request_id}: '
125
+ f'{auth_user.name}, {auth_user.id}')
126
+ if 'env_vars' in original_json:
127
+ if isinstance(original_json.get('env_vars'), dict):
128
+ original_json['env_vars'][
129
+ constants.USER_ID_ENV_VAR] = auth_user.id
130
+ original_json['env_vars'][
131
+ constants.USER_ENV_VAR] = auth_user.name
132
+ else:
133
+ logger.warning(
134
+ f'"env_vars" in request body is not a dictionary '
135
+ f'for request {request.state.request_id}. '
136
+ 'Skipping user info injection into body.')
137
+ else:
138
+ original_json['env_vars'] = {}
139
+ original_json['env_vars'][
140
+ constants.USER_ID_ENV_VAR] = auth_user.id
141
+ original_json['env_vars'][
142
+ constants.USER_ENV_VAR] = auth_user.name
143
+ request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
144
+
145
+
146
+ def _try_set_basic_auth_user(request: fastapi.Request):
147
+ auth_header = request.headers.get('authorization')
148
+ if not auth_header or not auth_header.lower().startswith('basic '):
149
+ return
150
+
151
+ # Check username and password
152
+ encoded = auth_header.split(' ', 1)[1]
153
+ try:
154
+ decoded = base64.b64decode(encoded).decode()
155
+ username, password = decoded.split(':', 1)
156
+ except Exception: # pylint: disable=broad-except
157
+ return
158
+
159
+ users = global_user_state.get_user_by_name(username)
160
+ if not users:
161
+ return
162
+
163
+ for user in users:
164
+ if not user.name or not user.password:
165
+ continue
166
+ username_encoded = username.encode('utf8')
167
+ db_username_encoded = user.name.encode('utf8')
168
+ if (username_encoded == db_username_encoded and
169
+ apr_md5_crypt.verify(password, user.password)):
170
+ request.state.auth_user = user
171
+ break
172
+
173
+
105
174
  class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
106
175
  """Middleware to handle RBAC."""
107
176
 
@@ -112,7 +181,7 @@ class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
112
181
  request.url.path.startswith('/api/')):
113
182
  return await call_next(request)
114
183
 
115
- auth_user = _get_auth_user_header(request)
184
+ auth_user = request.state.auth_user
116
185
  if auth_user is None:
117
186
  return await call_next(request)
118
187
 
@@ -149,6 +218,50 @@ def _get_auth_user_header(request: fastapi.Request) -> Optional[models.User]:
149
218
  return models.User(id=user_hash, name=user_name)
150
219
 
151
220
 
221
+ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
222
+ """Middleware to handle HTTP Basic Auth."""
223
+
224
+ async def dispatch(self, request: fastapi.Request, call_next):
225
+ if request.url.path.startswith('/api/'):
226
+ # Try to set the auth user from the basic auth header so the
227
+ # following endpoint handlers can leverage the auth_user info
228
+ _try_set_basic_auth_user(request)
229
+ return await call_next(request)
230
+
231
+ auth_header = request.headers.get('authorization')
232
+ if not auth_header or not auth_header.lower().startswith('basic '):
233
+ return _basic_auth_401_response('Invalid basic auth')
234
+
235
+ # Check username and password
236
+ encoded = auth_header.split(' ', 1)[1]
237
+ try:
238
+ decoded = base64.b64decode(encoded).decode()
239
+ username, password = decoded.split(':', 1)
240
+ except Exception: # pylint: disable=broad-except
241
+ return _basic_auth_401_response('Invalid basic auth')
242
+
243
+ users = global_user_state.get_user_by_name(username)
244
+ if not users:
245
+ return _basic_auth_401_response('Invalid credentials')
246
+
247
+ valid_user = False
248
+ for user in users:
249
+ if not user.name or not user.password:
250
+ continue
251
+ username_encoded = username.encode('utf8')
252
+ db_username_encoded = user.name.encode('utf8')
253
+ if (username_encoded == db_username_encoded and
254
+ apr_md5_crypt.verify(password, user.password)):
255
+ valid_user = True
256
+ request.state.auth_user = user
257
+ await _override_user_info_in_request_body(request, user)
258
+ break
259
+ if not valid_user:
260
+ return _basic_auth_401_response('Invalid credentials')
261
+
262
+ return await call_next(request)
263
+
264
+
152
265
  class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
153
266
  """Middleware to handle auth proxy."""
154
267
 
@@ -168,33 +281,7 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
168
281
  else:
169
282
  request.state.auth_user = None
170
283
 
171
- body = await request.body()
172
- if auth_user and body:
173
- try:
174
- original_json = await request.json()
175
- except json.JSONDecodeError as e:
176
- logger.error(f'Error parsing request JSON: {e}')
177
- else:
178
- logger.debug(f'Overriding user for {request.state.request_id}: '
179
- f'{auth_user.name}, {auth_user.id}')
180
- if 'env_vars' in original_json:
181
- if isinstance(original_json.get('env_vars'), dict):
182
- original_json['env_vars'][
183
- constants.USER_ID_ENV_VAR] = auth_user.id
184
- original_json['env_vars'][
185
- constants.USER_ENV_VAR] = auth_user.name
186
- else:
187
- logger.warning(
188
- f'"env_vars" in request body is not a dictionary '
189
- f'for request {request.state.request_id}. '
190
- 'Skipping user info injection into body.')
191
- else:
192
- original_json['env_vars'] = {}
193
- original_json['env_vars'][
194
- constants.USER_ID_ENV_VAR] = auth_user.id
195
- original_json['env_vars'][
196
- constants.USER_ENV_VAR] = auth_user.name
197
- request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
284
+ await _override_user_info_in_request_body(request, auth_user)
198
285
  return await call_next(request)
199
286
 
200
287
 
@@ -306,6 +393,9 @@ app.add_middleware(
306
393
  allow_headers=['*'],
307
394
  # TODO(syang): remove X-Request-ID when v0.10.0 is released.
308
395
  expose_headers=['X-Request-ID', 'X-Skypilot-Request-ID'])
396
+ enable_basic_auth = os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false')
397
+ if str(enable_basic_auth).lower() == 'true':
398
+ app.add_middleware(BasicAuthMiddleware)
309
399
  app.add_middleware(AuthProxyMiddleware)
310
400
  app.add_middleware(RequestIDMiddleware)
311
401
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
@@ -1232,7 +1322,7 @@ async def health(request: fastapi.Request) -> Dict[str, Any]:
1232
1322
  disk, which can be used to warn about restarting the API server
1233
1323
  - commit: str; The commit hash of SkyPilot used for API server.
1234
1324
  """
1235
- user = _get_auth_user_header(request)
1325
+ user = request.state.auth_user
1236
1326
  return {
1237
1327
  'status': common.ApiServerStatus.HEALTHY.value,
1238
1328
  'api_version': server_constants.API_VERSION,
@@ -1240,6 +1330,8 @@ async def health(request: fastapi.Request) -> Dict[str, Any]:
1240
1330
  'version_on_disk': common.get_skypilot_version_on_disk(),
1241
1331
  'commit': sky.__commit__,
1242
1332
  'user': user.to_dict() if user is not None else None,
1333
+ 'basic_auth_enabled': os.environ.get(
1334
+ constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false').lower() == 'true',
1243
1335
  }
1244
1336
 
1245
1337
 
sky/setup_files/dependencies.py CHANGED
@@ -58,8 +58,17 @@ install_requires = [
58
58
  'setproctitle',
59
59
  'sqlalchemy',
60
60
  'psycopg2-binary',
61
+ # TODO(hailong): These three dependencies should be removed after we make
62
+ # the client-side actually not importing them.
61
63
  'casbin',
62
64
  'sqlalchemy_adapter',
65
+ 'passlib',
66
+ ]
67
+
68
+ server_dependencies = [
69
+ 'casbin',
70
+ 'sqlalchemy_adapter',
71
+ 'passlib',
63
72
  ]
64
73
 
65
74
  local_ray = [
@@ -162,7 +171,8 @@ extras_require: Dict[str, List[str]] = {
162
171
  'nebius': [
163
172
  'nebius>=0.2.0',
164
173
  ] + aws_dependencies,
165
- 'hyperbolic': [] # No dependencies needed for hyperbolic
174
+ 'hyperbolic': [], # No dependencies needed for hyperbolic
175
+ 'server': server_dependencies,
166
176
  }
167
177
 
168
178
  # Nebius needs python3.10. If python 3.9 [all] will not install nebius