skypilot-nightly 1.0.0.dev20251203__py3-none-any.whl → 1.0.0.dev20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. sky/__init__.py +4 -2
  2. sky/adaptors/aws.py +1 -61
  3. sky/adaptors/slurm.py +478 -0
  4. sky/backends/backend_utils.py +45 -4
  5. sky/backends/cloud_vm_ray_backend.py +32 -33
  6. sky/backends/task_codegen.py +340 -2
  7. sky/catalog/__init__.py +0 -3
  8. sky/catalog/kubernetes_catalog.py +12 -4
  9. sky/catalog/slurm_catalog.py +243 -0
  10. sky/check.py +14 -3
  11. sky/client/cli/command.py +329 -22
  12. sky/client/sdk.py +56 -2
  13. sky/clouds/__init__.py +2 -0
  14. sky/clouds/cloud.py +7 -0
  15. sky/clouds/slurm.py +578 -0
  16. sky/clouds/ssh.py +2 -1
  17. sky/clouds/vast.py +10 -0
  18. sky/core.py +128 -36
  19. sky/dashboard/out/404.html +1 -1
  20. sky/dashboard/out/_next/static/KYAhEFa3FTfq4JyKVgo-s/_buildManifest.js +1 -0
  21. sky/dashboard/out/_next/static/chunks/3294.ddda8c6c6f9f24dc.js +1 -0
  22. sky/dashboard/out/_next/static/chunks/3850-fd5696f3bbbaddae.js +1 -0
  23. sky/dashboard/out/_next/static/chunks/6856-da20c5fd999f319c.js +1 -0
  24. sky/dashboard/out/_next/static/chunks/6990-09cbf02d3cd518c3.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/9353-8369df1cf105221c.js +1 -0
  26. sky/dashboard/out/_next/static/chunks/pages/_app-68b647e26f9d2793.js +34 -0
  27. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33f525539665fdfd.js +16 -0
  28. sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-abfcac9c137aa543.js → [cluster]-a7565f586ef86467.js} +1 -1
  29. sky/dashboard/out/_next/static/chunks/pages/{clusters-ee39056f9851a3ff.js → clusters-9e5d47818b9bdadd.js} +1 -1
  30. sky/dashboard/out/_next/static/chunks/pages/{config-dfb9bf07b13045f4.js → config-718cdc365de82689.js} +1 -1
  31. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-c0b5935149902e6f.js → [context]-12c559ec4d81fdbd.js} +1 -1
  32. sky/dashboard/out/_next/static/chunks/pages/{infra-aed0ea19df7cf961.js → infra-d187cd0413d72475.js} +1 -1
  33. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-895847b6cf200b04.js +16 -0
  34. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/{[pool]-9faf940b253e3e06.js → [pool]-8d0f4655400b4eb9.js} +2 -2
  35. sky/dashboard/out/_next/static/chunks/pages/{jobs-2072b48b617989c9.js → jobs-e5a98f17f8513a96.js} +1 -1
  36. sky/dashboard/out/_next/static/chunks/pages/plugins/[...slug]-4f46050ca065d8f8.js +1 -0
  37. sky/dashboard/out/_next/static/chunks/pages/{users-f42674164aa73423.js → users-2f7646eb77785a2c.js} +1 -1
  38. sky/dashboard/out/_next/static/chunks/pages/{volumes-b84b948ff357c43e.js → volumes-ef19d49c6d0e8500.js} +1 -1
  39. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-84a40f8c7c627fe4.js → [name]-96e0f298308da7e2.js} +1 -1
  40. sky/dashboard/out/_next/static/chunks/pages/{workspaces-531b2f8c4bf89f82.js → workspaces-cb4da3abe08ebf19.js} +1 -1
  41. sky/dashboard/out/_next/static/chunks/{webpack-64e05f17bf2cf8ce.js → webpack-fba3de387ff6bb08.js} +1 -1
  42. sky/dashboard/out/_next/static/css/c5a4cfd2600fc715.css +3 -0
  43. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  44. sky/dashboard/out/clusters/[cluster].html +1 -1
  45. sky/dashboard/out/clusters.html +1 -1
  46. sky/dashboard/out/config.html +1 -1
  47. sky/dashboard/out/index.html +1 -1
  48. sky/dashboard/out/infra/[context].html +1 -1
  49. sky/dashboard/out/infra.html +1 -1
  50. sky/dashboard/out/jobs/[job].html +1 -1
  51. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  52. sky/dashboard/out/jobs.html +1 -1
  53. sky/dashboard/out/plugins/[...slug].html +1 -0
  54. sky/dashboard/out/users.html +1 -1
  55. sky/dashboard/out/volumes.html +1 -1
  56. sky/dashboard/out/workspace/new.html +1 -1
  57. sky/dashboard/out/workspaces/[name].html +1 -1
  58. sky/dashboard/out/workspaces.html +1 -1
  59. sky/data/mounting_utils.py +16 -2
  60. sky/global_user_state.py +3 -3
  61. sky/models.py +2 -0
  62. sky/optimizer.py +6 -5
  63. sky/provision/__init__.py +1 -0
  64. sky/provision/common.py +20 -0
  65. sky/provision/docker_utils.py +15 -2
  66. sky/provision/kubernetes/utils.py +42 -6
  67. sky/provision/provisioner.py +15 -6
  68. sky/provision/slurm/__init__.py +12 -0
  69. sky/provision/slurm/config.py +13 -0
  70. sky/provision/slurm/instance.py +572 -0
  71. sky/provision/slurm/utils.py +583 -0
  72. sky/provision/vast/instance.py +4 -1
  73. sky/provision/vast/utils.py +10 -6
  74. sky/serve/server/impl.py +1 -1
  75. sky/server/constants.py +1 -1
  76. sky/server/plugins.py +222 -0
  77. sky/server/requests/executor.py +5 -2
  78. sky/server/requests/payloads.py +12 -1
  79. sky/server/requests/request_names.py +2 -0
  80. sky/server/requests/requests.py +5 -1
  81. sky/server/requests/serializers/encoders.py +17 -0
  82. sky/server/requests/serializers/return_value_serializers.py +60 -0
  83. sky/server/server.py +78 -8
  84. sky/server/server_utils.py +30 -0
  85. sky/setup_files/dependencies.py +2 -0
  86. sky/skylet/attempt_skylet.py +13 -3
  87. sky/skylet/constants.py +34 -9
  88. sky/skylet/events.py +10 -4
  89. sky/skylet/executor/__init__.py +1 -0
  90. sky/skylet/executor/slurm.py +189 -0
  91. sky/skylet/job_lib.py +2 -1
  92. sky/skylet/log_lib.py +22 -6
  93. sky/skylet/log_lib.pyi +8 -6
  94. sky/skylet/skylet.py +5 -1
  95. sky/skylet/subprocess_daemon.py +2 -1
  96. sky/ssh_node_pools/constants.py +12 -0
  97. sky/ssh_node_pools/core.py +40 -3
  98. sky/ssh_node_pools/deploy/__init__.py +4 -0
  99. sky/{utils/kubernetes/deploy_ssh_node_pools.py → ssh_node_pools/deploy/deploy.py} +279 -504
  100. sky/ssh_node_pools/deploy/tunnel_utils.py +199 -0
  101. sky/ssh_node_pools/deploy/utils.py +173 -0
  102. sky/ssh_node_pools/server.py +11 -13
  103. sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py} +9 -6
  104. sky/templates/kubernetes-ray.yml.j2 +8 -0
  105. sky/templates/slurm-ray.yml.j2 +85 -0
  106. sky/templates/vast-ray.yml.j2 +1 -0
  107. sky/users/model.conf +1 -1
  108. sky/users/permission.py +24 -1
  109. sky/users/rbac.py +31 -3
  110. sky/utils/annotations.py +108 -8
  111. sky/utils/command_runner.py +197 -5
  112. sky/utils/command_runner.pyi +27 -4
  113. sky/utils/common_utils.py +18 -3
  114. sky/utils/kubernetes/kubernetes_deploy_utils.py +2 -94
  115. sky/utils/kubernetes/ssh-tunnel.sh +7 -376
  116. sky/utils/schemas.py +31 -0
  117. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/METADATA +48 -36
  118. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/RECORD +125 -107
  119. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +0 -1
  120. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +0 -1
  121. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +0 -1
  122. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +0 -1
  123. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +0 -1
  124. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +0 -1
  125. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +0 -34
  126. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +0 -16
  127. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +0 -16
  128. sky/dashboard/out/_next/static/css/0748ce22df867032.css +0 -3
  129. sky/utils/kubernetes/cleanup-tunnel.sh +0 -62
  130. /sky/dashboard/out/_next/static/{96_E2yl3QAiIJGOYCkSpB → KYAhEFa3FTfq4JyKVgo-s}/_ssgManifest.js +0 -0
  131. /sky/dashboard/out/_next/static/chunks/{1141-e6aa9ab418717c59.js → 1141-9c810f01ff4f398a.js} +0 -0
  132. /sky/dashboard/out/_next/static/chunks/{3800-7b45f9fbb6308557.js → 3800-b589397dc09c5b4e.js} +0 -0
  133. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/WHEEL +0 -0
  134. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/entry_points.txt +0 -0
  135. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/licenses/LICENSE +0 -0
  136. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/top_level.txt +0 -0
sky/server/plugins.py ADDED
@@ -0,0 +1,222 @@
1
+ """Load plugins for the SkyPilot API server."""
2
+ import abc
3
+ import dataclasses
4
+ import importlib
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ from fastapi import FastAPI
9
+
10
+ from sky import sky_logging
11
+ from sky.skylet import constants as skylet_constants
12
+ from sky.utils import common_utils
13
+ from sky.utils import config_utils
14
+ from sky.utils import yaml_utils
15
+
16
+ logger = sky_logging.init_logger(__name__)
17
+
18
+ _DEFAULT_PLUGINS_CONFIG_PATH = '~/.sky/plugins.yaml'
19
+ _PLUGINS_CONFIG_ENV_VAR = (
20
+ f'{skylet_constants.SKYPILOT_SERVER_ENV_VAR_PREFIX}PLUGINS_CONFIG')
21
+
22
+
23
+ class ExtensionContext:
24
+ """Context provided to plugins during installation.
25
+
26
+ Attributes:
27
+ app: The FastAPI application instance.
28
+ rbac_rules: List of RBAC rules registered by the plugin.
29
+ Example:
30
+ [
31
+ ('user', RBACRule(path='/plugins/api/xx/*', method='POST')),
32
+ ('user', RBACRule(path='/plugins/api/xx/*', method='DELETE'))
33
+ ]
34
+ """
35
+
36
+ def __init__(self, app: Optional[FastAPI] = None):
37
+ self.app = app
38
+ self.rbac_rules: List[Tuple[str, RBACRule]] = []
39
+
40
+ def register_rbac_rule(self,
41
+ path: str,
42
+ method: str,
43
+ description: Optional[str] = None,
44
+ role: str = 'user') -> None:
45
+ """Register an RBAC rule for this plugin.
46
+
47
+ This method allows plugins to declare, during the install phase,
48
+ which endpoints should be restricted to admin users.
49
+
50
+ Args:
51
+ path: The path pattern to restrict (supports wildcards with
52
+ keyMatch2).
53
+ Example: '/plugins/api/credentials/*'
54
+ method: The HTTP method to restrict. Example: 'POST', 'DELETE'
55
+ description: Optional description of what this rule protects.
56
+ role: The role to add this rule to (default: 'user').
57
+ Rules added to 'user' role block regular users but allow
58
+ admins.
59
+
60
+ Example:
61
+ def install(self, ctx: ExtensionContext):
62
+ # Only admin can upload credentials
63
+ ctx.register_rbac_rule(
64
+ path='/plugins/api/credentials/*',
65
+ method='POST',
66
+ description='Only admin can upload credentials'
67
+ )
68
+ """
69
+ rule = RBACRule(path=path, method=method, description=description)
70
+ self.rbac_rules.append((role, rule))
71
+ logger.debug(f'Registered RBAC rule for {role}: {method} {path}'
72
+ f'{f" - {description}" if description else ""}')
73
+
74
+
75
+ @dataclasses.dataclass
76
+ class RBACRule:
77
+ """RBAC rule for a plugin endpoint.
78
+
79
+ Attributes:
80
+ path: The path pattern to match (supports wildcards with keyMatch2).
81
+ Example: '/plugins/api/credentials/*'
82
+ method: The HTTP method to restrict. Example: 'POST', 'DELETE'
83
+ description: Optional description of what this rule protects.
84
+ """
85
+ path: str
86
+ method: str
87
+ description: Optional[str] = None
88
+
89
+
90
+ class BasePlugin(abc.ABC):
91
+ """Base class for all SkyPilot server plugins."""
92
+
93
+ @property
94
+ def js_extension_path(self) -> Optional[str]:
95
+ """Optional API route to the JavaScript extension to load."""
96
+ return None
97
+
98
+ @abc.abstractmethod
99
+ def install(self, extension_context: ExtensionContext):
100
+ """Hook called by API server to let the plugin install itself."""
101
+ raise NotImplementedError
102
+
103
+ def shutdown(self):
104
+ """Hook called by API server to let the plugin shutdown."""
105
+ pass
106
+
107
+
108
+ def _config_schema():
109
+ plugin_schema = {
110
+ 'type': 'object',
111
+ 'required': ['class'],
112
+ 'additionalProperties': False,
113
+ 'properties': {
114
+ 'class': {
115
+ 'type': 'string',
116
+ },
117
+ 'parameters': {
118
+ 'type': 'object',
119
+ 'required': [],
120
+ 'additionalProperties': True,
121
+ },
122
+ },
123
+ }
124
+ return {
125
+ 'type': 'object',
126
+ 'required': [],
127
+ 'additionalProperties': False,
128
+ 'properties': {
129
+ 'plugins': {
130
+ 'type': 'array',
131
+ 'items': plugin_schema,
132
+ 'default': [],
133
+ },
134
+ },
135
+ }
136
+
137
+
138
+ def _load_plugin_config() -> Optional[config_utils.Config]:
139
+ """Load plugin config."""
140
+ config_path = os.getenv(_PLUGINS_CONFIG_ENV_VAR,
141
+ _DEFAULT_PLUGINS_CONFIG_PATH)
142
+ config_path = os.path.expanduser(config_path)
143
+ if not os.path.exists(config_path):
144
+ return None
145
+ config = yaml_utils.read_yaml(config_path) or {}
146
+ common_utils.validate_schema(config,
147
+ _config_schema(),
148
+ err_msg_prefix='Invalid plugins config: ')
149
+ return config_utils.Config.from_dict(config)
150
+
151
+
152
+ _PLUGINS: Dict[str, BasePlugin] = {}
153
+ _EXTENSION_CONTEXT: Optional[ExtensionContext] = None
154
+
155
+
156
+ def load_plugins(extension_context: ExtensionContext):
157
+ """Load and initialize plugins from the config."""
158
+ global _EXTENSION_CONTEXT
159
+ _EXTENSION_CONTEXT = extension_context
160
+
161
+ config = _load_plugin_config()
162
+ if not config:
163
+ return
164
+
165
+ for plugin_config in config.get('plugins', []):
166
+ class_path = plugin_config['class']
167
+ module_path, class_name = class_path.rsplit('.', 1)
168
+ try:
169
+ module = importlib.import_module(module_path)
170
+ except ImportError as e:
171
+ raise ImportError(
172
+ f'Failed to import plugin module: {module_path}. '
173
+ 'Please check if the module is installed in your Python '
174
+ 'environment.') from e
175
+ try:
176
+ plugin_cls = getattr(module, class_name)
177
+ except AttributeError as e:
178
+ raise AttributeError(
179
+ f'Could not find plugin {class_name} class in module '
180
+ f'{module_path}. ') from e
181
+ if not issubclass(plugin_cls, BasePlugin):
182
+ raise TypeError(
183
+ f'Plugin {class_path} must inherit from BasePlugin.')
184
+ parameters = plugin_config.get('parameters') or {}
185
+ plugin = plugin_cls(**parameters)
186
+ plugin.install(extension_context)
187
+ _PLUGINS[class_path] = plugin
188
+
189
+
190
+ def get_plugins() -> List[BasePlugin]:
191
+ """Return a new list containing the registered plugin instances."""
192
+ return list(_PLUGINS.values())
193
+
194
+
195
+ def get_plugin_rbac_rules() -> Dict[str, List[Dict[str, str]]]:
196
+ """Collect RBAC rules from all loaded plugins.
197
+
198
+ Collects rules from the ExtensionContext.
199
+
200
+ Returns:
201
+ Dictionary mapping role names to lists of blocklist rules.
202
+ Example:
203
+ {
204
+ 'user': [
205
+ {'path': '/plugins/api/credentials/*', 'method': 'POST'},
206
+ {'path': '/plugins/api/credentials/*', 'method': 'DELETE'}
207
+ ]
208
+ }
209
+ """
210
+ rules_by_role: Dict[str, List[Dict[str, str]]] = {}
211
+
212
+ # Collect rules registered via ExtensionContext
213
+ if _EXTENSION_CONTEXT:
214
+ for role, rule in _EXTENSION_CONTEXT.rbac_rules:
215
+ if role not in rules_by_role:
216
+ rules_by_role[role] = []
217
+ rules_by_role[role].append({
218
+ 'path': rule.path,
219
+ 'method': rule.method,
220
+ })
221
+
222
+ return rules_by_role
@@ -44,6 +44,7 @@ from sky.server import common as server_common
44
44
  from sky.server import config as server_config
45
45
  from sky.server import constants as server_constants
46
46
  from sky.server import metrics as metrics_lib
47
+ from sky.server import plugins
47
48
  from sky.server.requests import payloads
48
49
  from sky.server.requests import preconditions
49
50
  from sky.server.requests import process
@@ -159,6 +160,8 @@ queue_backend = server_config.QueueBackend.MULTIPROCESSING
159
160
  def executor_initializer(proc_group: str):
160
161
  setproctitle.setproctitle(f'SkyPilot:executor:{proc_group}:'
161
162
  f'{multiprocessing.current_process().pid}')
163
+ # Load plugins for executor process.
164
+ plugins.load_plugins(plugins.ExtensionContext())
162
165
  # Executor never stops, unless the whole process is killed.
163
166
  threading.Thread(target=metrics_lib.process_monitor,
164
167
  args=(f'worker:{proc_group}', threading.Event()),
@@ -533,8 +536,8 @@ def _request_execution_wrapper(request_id: str,
533
536
  # so that the "Request xxxx failed due to ..." log message will be
534
537
  # written to the original stdout and stderr file descriptors.
535
538
  _restore_output()
536
- logger.info(f'Request {request_id} failed due to '
537
- f'{common_utils.format_exception(e)}')
539
+ logger.error(f'Request {request_id} failed due to '
540
+ f'{common_utils.format_exception(e)}')
538
541
  return
539
542
  else:
540
543
  api_requests.set_request_succeeded(
@@ -82,7 +82,7 @@ def request_body_env_vars() -> dict:
82
82
  if common.is_api_server_local() and env_var in EXTERNAL_LOCAL_ENV_VARS:
83
83
  env_vars[env_var] = os.environ[env_var]
84
84
  env_vars[constants.USER_ID_ENV_VAR] = common_utils.get_user_hash()
85
- env_vars[constants.USER_ENV_VAR] = common_utils.get_current_user_name()
85
+ env_vars[constants.USER_ENV_VAR] = common_utils.get_local_user_name()
86
86
  env_vars[
87
87
  usage_constants.USAGE_RUN_ID_ENV_VAR] = usage_lib.messages.usage.run_id
88
88
  if not common.is_api_server_local():
@@ -670,6 +670,11 @@ class KubernetesNodeInfoRequestBody(RequestBody):
670
670
  context: Optional[str] = None
671
671
 
672
672
 
673
+ class SlurmNodeInfoRequestBody(RequestBody):
674
+ """The request body for the slurm node info endpoint."""
675
+ slurm_cluster_name: Optional[str] = None
676
+
677
+
673
678
  class ListAcceleratorsBody(RequestBody):
674
679
  """The request body for the list accelerators endpoint."""
675
680
  gpus_only: bool = True
@@ -854,3 +859,9 @@ class RequestPayload(BasePayload):
854
859
  status_msg: Optional[str] = None
855
860
  should_retry: bool = False
856
861
  finished_at: Optional[float] = None
862
+
863
+
864
+ class SlurmGpuAvailabilityRequestBody(RequestBody):
865
+ """Request body for getting Slurm real-time GPU availability."""
866
+ name_filter: Optional[str] = None
867
+ quantity_filter: Optional[int] = None
@@ -10,6 +10,8 @@ class RequestName(str, enum.Enum):
10
10
  REALTIME_KUBERNETES_GPU_AVAILABILITY = (
11
11
  'realtime_kubernetes_gpu_availability')
12
12
  KUBERNETES_NODE_INFO = 'kubernetes_node_info'
13
+ REALTIME_SLURM_GPU_AVAILABILITY = 'realtime_slurm_gpu_availability'
14
+ SLURM_NODE_INFO = 'slurm_node_info'
13
15
  STATUS_KUBERNETES = 'status_kubernetes'
14
16
  LIST_ACCELERATORS = 'list_accelerators'
15
17
  LIST_ACCELERATOR_COUNTS = 'list_accelerator_counts'
@@ -33,6 +33,7 @@ from sky.server import daemons
33
33
  from sky.server.requests import payloads
34
34
  from sky.server.requests.serializers import decoders
35
35
  from sky.server.requests.serializers import encoders
36
+ from sky.server.requests.serializers import return_value_serializers
36
37
  from sky.utils import asyncio_utils
37
38
  from sky.utils import common_utils
38
39
  from sky.utils import ux_utils
@@ -231,13 +232,16 @@ class Request:
231
232
  assert isinstance(self.request_body,
232
233
  payloads.RequestBody), (self.name, self.request_body)
233
234
  try:
235
+ # Use version-aware serializer to handle backward compatibility
236
+ # for old clients that don't recognize new fields.
237
+ serializer = return_value_serializers.get_serializer(self.name)
234
238
  return payloads.RequestPayload(
235
239
  request_id=self.request_id,
236
240
  name=self.name,
237
241
  entrypoint=encoders.pickle_and_encode(self.entrypoint),
238
242
  request_body=encoders.pickle_and_encode(self.request_body),
239
243
  status=self.status.value,
240
- return_value=orjson.dumps(self.return_value).decode('utf-8'),
244
+ return_value=serializer(self.return_value),
241
245
  error=orjson.dumps(self.error).decode('utf-8'),
242
246
  pid=self.pid,
243
247
  created_at=self.created_at,
@@ -266,6 +266,23 @@ def encode_realtime_gpu_availability(
266
266
  return encoded
267
267
 
268
268
 
269
+ @register_encoder('realtime_slurm_gpu_availability')
270
+ def encode_realtime_slurm_gpu_availability(
271
+ return_value: List[Tuple[str,
272
+ List[Any]]]) -> List[Tuple[str, List[List[Any]]]]:
273
+ # Convert RealtimeGpuAvailability namedtuples to lists
274
+ # for JSON serialization.
275
+ encoded = []
276
+ for context, gpu_list in return_value:
277
+ converted_gpu_list = []
278
+ for gpu in gpu_list:
279
+ assert isinstance(gpu, models.RealtimeGpuAvailability), (
280
+ f'Expected RealtimeGpuAvailability, got {type(gpu)}')
281
+ converted_gpu_list.append(list(gpu))
282
+ encoded.append((context, converted_gpu_list))
283
+ return encoded
284
+
285
+
269
286
  @register_encoder('list_accelerators')
270
287
  def encode_list_accelerators(
271
288
  return_value: Dict[str, List[Any]]) -> Dict[str, Any]:
@@ -0,0 +1,60 @@
1
+ """Version-aware serializers for request return values.
2
+
3
+ These serializers run at encode() time when remote_api_version is available,
4
+ to handle backward compatibility for old clients.
5
+
6
+ The existing encoders.py handles object -> dict conversion at set_return_value()
7
+ time. This module handles dict -> JSON string serialization at encode() time,
8
+ with version-aware field filtering for backward compatibility.
9
+ """
10
+ from typing import Any, Callable, Dict
11
+
12
+ import orjson
13
+
14
+ from sky.server import constants as server_constants
15
+ from sky.server import versions
16
+
17
+ handlers: Dict[str, Callable[[Any], str]] = {}
18
+
19
+
20
+ def register_serializer(*names: str):
21
+ """Decorator to register a version-aware serializer."""
22
+
23
+ def decorator(func):
24
+ for name in names:
25
+ if name != server_constants.DEFAULT_HANDLER_NAME:
26
+ name = server_constants.REQUEST_NAME_PREFIX + name
27
+ if name in handlers:
28
+ raise ValueError(f'Serializer {name} already registered: '
29
+ f'{handlers[name]}')
30
+ handlers[name] = func
31
+ return func
32
+
33
+ return decorator
34
+
35
+
36
+ def get_serializer(name: str) -> Callable[[Any], str]:
37
+ """Get the serializer for a request name."""
38
+ return handlers.get(name, handlers[server_constants.DEFAULT_HANDLER_NAME])
39
+
40
+
41
+ @register_serializer(server_constants.DEFAULT_HANDLER_NAME)
42
+ def default_serializer(return_value: Any) -> str:
43
+ """The default serializer."""
44
+ return orjson.dumps(return_value).decode('utf-8')
45
+
46
+
47
+ @register_serializer('kubernetes_node_info')
48
+ def serialize_kubernetes_node_info(return_value: Dict[str, Any]) -> str:
49
+ """Serialize kubernetes node info with version compatibility.
50
+
51
+ The is_ready field was added in API version 25. Remove it for old clients
52
+ that don't recognize it.
53
+ """
54
+ remote_api_version = versions.get_remote_api_version()
55
+ if (return_value and remote_api_version is not None and
56
+ remote_api_version < 25):
57
+ # Remove is_ready field for old clients that don't recognize it
58
+ for node_info in return_value.get('node_info_dict', {}).values():
59
+ node_info.pop('is_ready', None)
60
+ return orjson.dumps(return_value).decode('utf-8')
sky/server/server.py CHANGED
@@ -20,7 +20,7 @@ import struct
20
20
  import sys
21
21
  import threading
22
22
  import traceback
23
- from typing import Dict, List, Literal, Optional, Set, Tuple
23
+ from typing import Any, Dict, List, Literal, Optional, Set, Tuple
24
24
  import uuid
25
25
  import zipfile
26
26
 
@@ -48,6 +48,7 @@ from sky.jobs.server import server as jobs_rest
48
48
  from sky.metrics import utils as metrics_utils
49
49
  from sky.provision import metadata_utils
50
50
  from sky.provision.kubernetes import utils as kubernetes_utils
51
+ from sky.provision.slurm import utils as slurm_utils
51
52
  from sky.schemas.api import responses
52
53
  from sky.serve.server import server as serve_rest
53
54
  from sky.server import common
@@ -56,6 +57,8 @@ from sky.server import constants as server_constants
56
57
  from sky.server import daemons
57
58
  from sky.server import metrics
58
59
  from sky.server import middleware_utils
60
+ from sky.server import plugins
61
+ from sky.server import server_utils
59
62
  from sky.server import state
60
63
  from sky.server import stream_utils
61
64
  from sky.server import versions
@@ -470,7 +473,8 @@ async def schedule_on_boot_check_async():
470
473
  await executor.schedule_request_async(
471
474
  request_id='skypilot-server-on-boot-check',
472
475
  request_name=request_names.RequestName.CHECK,
473
- request_body=payloads.CheckBody(),
476
+ request_body=server_utils.build_body_at_server(
477
+ request=None, body_type=payloads.CheckBody),
474
478
  func=sky_check.check,
475
479
  schedule_type=requests_lib.ScheduleType.SHORT,
476
480
  is_skypilot_system=True,
@@ -493,7 +497,8 @@ async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-nam
493
497
  await executor.schedule_request_async(
494
498
  request_id=event.id,
495
499
  request_name=event.name,
496
- request_body=payloads.RequestBody(),
500
+ request_body=server_utils.build_body_at_server(
501
+ request=None, body_type=payloads.RequestBody),
497
502
  func=event.run_event,
498
503
  schedule_type=requests_lib.ScheduleType.SHORT,
499
504
  is_skypilot_system=True,
@@ -652,6 +657,17 @@ app.add_middleware(BearerTokenMiddleware)
652
657
  # middleware above.
653
658
  app.add_middleware(InitializeRequestAuthUserMiddleware)
654
659
  app.add_middleware(RequestIDMiddleware)
660
+
661
+ # Load plugins after all the middlewares are added, to keep the core
662
+ # middleware stack intact if a plugin adds new middlewares.
663
+ # Note: server.py is imported twice in the server process: once as the
664
+ # top-level entrypoint module and once by uvicorn. We only load the
665
+ # plugins when the module is imported by uvicorn.
666
+ # TODO(aylei): move uvicorn app out of the top-level module to avoid
667
+ # duplicate app initialization.
668
+ if __name__ == 'sky.server.server':
669
+ plugins.load_plugins(plugins.ExtensionContext(app=app))
670
+
655
671
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
656
672
  app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
657
673
  app.include_router(users_rest.router, prefix='/users', tags=['users'])
@@ -746,8 +762,11 @@ async def enabled_clouds(request: fastapi.Request,
746
762
  await executor.schedule_request_async(
747
763
  request_id=request.state.request_id,
748
764
  request_name=request_names.RequestName.ENABLED_CLOUDS,
749
- request_body=payloads.EnabledCloudsBody(workspace=workspace,
750
- expand=expand),
765
+ request_body=server_utils.build_body_at_server(
766
+ request=request,
767
+ body_type=payloads.EnabledCloudsBody,
768
+ workspace=workspace,
769
+ expand=expand),
751
770
  func=core.enabled_clouds,
752
771
  schedule_type=requests_lib.ScheduleType.SHORT,
753
772
  )
@@ -784,6 +803,35 @@ async def kubernetes_node_info(
784
803
  )
785
804
 
786
805
 
806
+ @app.post('/slurm_gpu_availability')
807
+ async def slurm_gpu_availability(
808
+ request: fastapi.Request,
809
+ slurm_gpu_availability_body: payloads.SlurmGpuAvailabilityRequestBody
810
+ ) -> None:
811
+ """Gets real-time Slurm GPU availability."""
812
+ await executor.schedule_request_async(
813
+ request_id=request.state.request_id,
814
+ request_name=request_names.RequestName.REALTIME_SLURM_GPU_AVAILABILITY,
815
+ request_body=slurm_gpu_availability_body,
816
+ func=core.realtime_slurm_gpu_availability,
817
+ schedule_type=requests_lib.ScheduleType.SHORT,
818
+ )
819
+
820
+
821
+ @app.get('/slurm_node_info')
822
+ async def slurm_node_info(
823
+ request: fastapi.Request,
824
+ slurm_node_info_body: payloads.SlurmNodeInfoRequestBody) -> None:
825
+ """Gets detailed information for each node in the Slurm cluster."""
826
+ await executor.schedule_request_async(
827
+ request_id=request.state.request_id,
828
+ request_name=request_names.RequestName.SLURM_NODE_INFO,
829
+ request_body=slurm_node_info_body,
830
+ func=slurm_utils.slurm_node_info,
831
+ schedule_type=requests_lib.ScheduleType.SHORT,
832
+ )
833
+
834
+
787
835
  @app.get('/status_kubernetes')
788
836
  async def status_kubernetes(request: fastapi.Request) -> None:
789
837
  """[Experimental] Get all SkyPilot resources (including from other '
@@ -791,7 +839,8 @@ async def status_kubernetes(request: fastapi.Request) -> None:
791
839
  await executor.schedule_request_async(
792
840
  request_id=request.state.request_id,
793
841
  request_name=request_names.RequestName.STATUS_KUBERNETES,
794
- request_body=payloads.RequestBody(),
842
+ request_body=server_utils.build_body_at_server(
843
+ request=request, body_type=payloads.RequestBody),
795
844
  func=core.status_kubernetes,
796
845
  schedule_type=requests_lib.ScheduleType.SHORT,
797
846
  )
@@ -1460,7 +1509,8 @@ async def storage_ls(request: fastapi.Request) -> None:
1460
1509
  await executor.schedule_request_async(
1461
1510
  request_id=request.state.request_id,
1462
1511
  request_name=request_names.RequestName.STORAGE_LS,
1463
- request_body=payloads.RequestBody(),
1512
+ request_body=server_utils.build_body_at_server(
1513
+ request=request, body_type=payloads.RequestBody),
1464
1514
  func=core.storage_ls,
1465
1515
  schedule_type=requests_lib.ScheduleType.SHORT,
1466
1516
  )
@@ -1752,6 +1802,15 @@ async def api_status(
1752
1802
  return encoded_request_tasks
1753
1803
 
1754
1804
 
1805
+ @app.get('/api/plugins', response_class=fastapi_responses.ORJSONResponse)
1806
+ async def list_plugins() -> Dict[str, List[Dict[str, Any]]]:
1807
+ """Return metadata about loaded backend plugins."""
1808
+ plugin_info = [{
1809
+ 'js_extension_path': plugin.js_extension_path,
1810
+ } for plugin in plugins.get_plugins()]
1811
+ return {'plugins': plugin_info}
1812
+
1813
+
1755
1814
  @app.get(
1756
1815
  '/api/health',
1757
1816
  # response_model_exclude_unset omits unset fields
@@ -2007,7 +2066,8 @@ async def all_contexts(request: fastapi.Request) -> None:
2007
2066
  await executor.schedule_request_async(
2008
2067
  request_id=request.state.request_id,
2009
2068
  request_name=request_names.RequestName.ALL_CONTEXTS,
2010
- request_body=payloads.RequestBody(),
2069
+ request_body=server_utils.build_body_at_server(
2070
+ request=request, body_type=payloads.RequestBody),
2011
2071
  func=core.get_all_contexts,
2012
2072
  schedule_type=requests_lib.ScheduleType.SHORT,
2013
2073
  )
@@ -2057,6 +2117,14 @@ async def serve_dashboard(full_path: str):
2057
2117
  if os.path.isfile(file_path):
2058
2118
  return fastapi.responses.FileResponse(file_path)
2059
2119
 
2120
+ # Serve plugin catch-all page for any /plugins/* paths so client-side
2121
+ # routing can bootstrap correctly.
2122
+ if full_path == 'plugins' or full_path.startswith('plugins/'):
2123
+ plugin_catchall = os.path.join(server_constants.DASHBOARD_DIR,
2124
+ 'plugins', '[...slug].html')
2125
+ if os.path.isfile(plugin_catchall):
2126
+ return fastapi.responses.FileResponse(plugin_catchall)
2127
+
2060
2128
  # Serve index.html for client-side routing
2061
2129
  # e.g. /clusters, /jobs
2062
2130
  index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
@@ -2220,6 +2288,8 @@ if __name__ == '__main__':
2220
2288
 
2221
2289
  for gt in global_tasks:
2222
2290
  gt.cancel()
2291
+ for plugin in plugins.get_plugins():
2292
+ plugin.shutdown()
2223
2293
  subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
2224
2294
  workers,
2225
2295
  num_threads=len(workers))
@@ -0,0 +1,30 @@
1
+ """Utilities for the API server."""
2
+
3
+ from typing import Optional, Type, TypeVar
4
+
5
+ import fastapi
6
+
7
+ from sky.server.requests import payloads
8
+ from sky.skylet import constants
9
+
10
+ _BodyT = TypeVar('_BodyT', bound=payloads.RequestBody)
11
+
12
+
13
+ # TODO(aylei): remove this and disable request body construction at server-side
14
+ def build_body_at_server(request: Optional[fastapi.Request],
15
+ body_type: Type[_BodyT], **data) -> _BodyT:
16
+ """Builds the request body at the server.
17
+
18
+ For historical reasons, some handlers mimic a client request body
19
+ at server-side in order to coordinate with the interface of executor.
20
+ As a result, client info such as the user identity is not respected
21
+ in these handlers. This function is a helper to build the request
22
+ body at server-side with the auth user overridden.
23
+ """
24
+ request_body = body_type(**data)
25
+ if request is not None:
26
+ auth_user = getattr(request.state, 'auth_user', None)
27
+ if auth_user:
28
+ request_body.env_vars[constants.USER_ID_ENV_VAR] = auth_user.id
29
+ request_body.env_vars[constants.USER_ENV_VAR] = auth_user.name
30
+ return request_body
@@ -84,6 +84,7 @@ install_requires = [
84
84
  'bcrypt==4.0.1',
85
85
  'pyjwt',
86
86
  'gitpython',
87
+ 'paramiko',
87
88
  'types-paramiko',
88
89
  'alembic',
89
90
  'aiohttp',
@@ -234,6 +235,7 @@ cloud_dependencies: Dict[str, List[str]] = {
234
235
  'hyperbolic': [], # No dependencies needed for hyperbolic
235
236
  'seeweb': ['ecsapi==0.4.0'],
236
237
  'shadeform': [], # No dependencies needed for shadeform
238
+ 'slurm': [], # No dependencies needed for slurm
237
239
  }
238
240
 
239
241
  # Calculate which clouds should be included in the [all] installation.
@@ -9,6 +9,7 @@ import psutil
9
9
 
10
10
  from sky.skylet import constants
11
11
  from sky.skylet import runtime_utils
12
+ from sky.utils import common_utils
12
13
 
13
14
  VERSION_FILE = runtime_utils.get_runtime_dir_path(constants.SKYLET_VERSION_FILE)
14
15
  SKYLET_LOG_FILE = runtime_utils.get_runtime_dir_path(constants.SKYLET_LOG_FILE)
@@ -97,8 +98,13 @@ def restart_skylet():
97
98
  for pid in _find_running_skylet_pids():
98
99
  try:
99
100
  os.kill(pid, signal.SIGKILL)
100
- except (OSError, ProcessLookupError):
101
- # Process died between detection and kill
101
+ # Wait until process fully terminates so its socket gets released.
102
+ # Without this, find_free_port may race with the kernel closing the
103
+ # socket and fail to bind to the port that's supposed to be free.
104
+ psutil.Process(pid).wait(timeout=5)
105
+ except (OSError, ProcessLookupError, psutil.NoSuchProcess,
106
+ psutil.TimeoutExpired):
107
+ # Process died between detection and kill, or timeout waiting
102
108
  pass
103
109
  # Clean up the PID file
104
110
  try:
@@ -106,7 +112,11 @@ def restart_skylet():
106
112
  except OSError:
107
113
  pass # Best effort cleanup
108
114
 
109
- port = constants.SKYLET_GRPC_PORT
115
+ # TODO(kevin): Handle race conditions here. Race conditions can only
116
+ # happen on Slurm, where there could be multiple clusters running in
117
+ # one network namespace. For other clouds, the behaviour will be that
118
+ # it always gets port 46590 (default port).
119
+ port = common_utils.find_free_port(constants.SKYLET_GRPC_PORT)
110
120
  subprocess.run(
111
121
  # We have made sure that `attempt_skylet.py` is executed with the
112
122
  # skypilot runtime env activated, so that skylet can access the cloud