skypilot-nightly 1.0.0.dev20250916__py3-none-any.whl → 1.0.0.dev20250919__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic; see the package's registry page for more details.

Files changed (81)
  1. sky/__init__.py +4 -2
  2. sky/adaptors/primeintellect.py +1 -0
  3. sky/adaptors/seeweb.py +68 -4
  4. sky/authentication.py +25 -0
  5. sky/backends/__init__.py +3 -2
  6. sky/backends/backend_utils.py +16 -12
  7. sky/backends/cloud_vm_ray_backend.py +57 -0
  8. sky/catalog/primeintellect_catalog.py +95 -0
  9. sky/clouds/__init__.py +2 -0
  10. sky/clouds/primeintellect.py +314 -0
  11. sky/core.py +77 -48
  12. sky/dashboard/out/404.html +1 -1
  13. sky/dashboard/out/_next/static/{y8s7LlyyfhMzpzCkxuD2r → VvaUqYDvHOcHZRnvMBmax}/_buildManifest.js +1 -1
  14. sky/dashboard/out/_next/static/chunks/1121-4ff1ec0dbc5792ab.js +1 -0
  15. sky/dashboard/out/_next/static/chunks/3015-88c7c8d69b0b6dba.js +1 -0
  16. sky/dashboard/out/_next/static/chunks/{6856-e0754534b3015377.js → 6856-9a2538f38c004652.js} +1 -1
  17. sky/dashboard/out/_next/static/chunks/8969-a39efbadcd9fde80.js +1 -0
  18. sky/dashboard/out/_next/static/chunks/9037-472ee1222cb1e158.js +6 -0
  19. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-1e9248ddbddcd122.js +16 -0
  20. sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-0b4b35dc1dfe046c.js → [cluster]-9525660179df3605.js} +1 -1
  21. sky/dashboard/out/_next/static/chunks/{webpack-05f82d90d6fd7f82.js → webpack-b2a3938c22b6647b.js} +1 -1
  22. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  23. sky/dashboard/out/clusters/[cluster].html +1 -1
  24. sky/dashboard/out/clusters.html +1 -1
  25. sky/dashboard/out/config.html +1 -1
  26. sky/dashboard/out/index.html +1 -1
  27. sky/dashboard/out/infra/[context].html +1 -1
  28. sky/dashboard/out/infra.html +1 -1
  29. sky/dashboard/out/jobs/[job].html +1 -1
  30. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  31. sky/dashboard/out/jobs.html +1 -1
  32. sky/dashboard/out/users.html +1 -1
  33. sky/dashboard/out/volumes.html +1 -1
  34. sky/dashboard/out/workspace/new.html +1 -1
  35. sky/dashboard/out/workspaces/[name].html +1 -1
  36. sky/dashboard/out/workspaces.html +1 -1
  37. sky/global_user_state.py +99 -62
  38. sky/jobs/server/server.py +14 -1
  39. sky/jobs/state.py +26 -1
  40. sky/metrics/utils.py +174 -8
  41. sky/provision/__init__.py +1 -0
  42. sky/provision/docker_utils.py +6 -2
  43. sky/provision/primeintellect/__init__.py +10 -0
  44. sky/provision/primeintellect/config.py +11 -0
  45. sky/provision/primeintellect/instance.py +454 -0
  46. sky/provision/primeintellect/utils.py +398 -0
  47. sky/resources.py +9 -1
  48. sky/schemas/generated/jobsv1_pb2.py +40 -40
  49. sky/schemas/generated/servev1_pb2.py +58 -0
  50. sky/schemas/generated/servev1_pb2.pyi +115 -0
  51. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  52. sky/serve/serve_rpc_utils.py +179 -0
  53. sky/serve/serve_utils.py +29 -12
  54. sky/serve/server/core.py +37 -19
  55. sky/serve/server/impl.py +221 -129
  56. sky/server/metrics.py +52 -158
  57. sky/server/requests/executor.py +12 -8
  58. sky/server/requests/payloads.py +6 -0
  59. sky/server/requests/requests.py +1 -1
  60. sky/server/requests/serializers/encoders.py +3 -2
  61. sky/server/server.py +5 -41
  62. sky/setup_files/dependencies.py +1 -0
  63. sky/skylet/constants.py +10 -5
  64. sky/skylet/job_lib.py +14 -15
  65. sky/skylet/services.py +98 -0
  66. sky/skylet/skylet.py +3 -1
  67. sky/templates/kubernetes-ray.yml.j2 +22 -12
  68. sky/templates/primeintellect-ray.yml.j2 +71 -0
  69. sky/utils/locks.py +41 -10
  70. {skypilot_nightly-1.0.0.dev20250916.dist-info → skypilot_nightly-1.0.0.dev20250919.dist-info}/METADATA +36 -35
  71. {skypilot_nightly-1.0.0.dev20250916.dist-info → skypilot_nightly-1.0.0.dev20250919.dist-info}/RECORD +76 -64
  72. sky/dashboard/out/_next/static/chunks/1121-408ed10b2f9fce17.js +0 -1
  73. sky/dashboard/out/_next/static/chunks/3015-2ea98b57e318bd6e.js +0 -1
  74. sky/dashboard/out/_next/static/chunks/8969-0487dfbf149d9e53.js +0 -1
  75. sky/dashboard/out/_next/static/chunks/9037-f9800e64eb05dd1c.js +0 -6
  76. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-1cbba24bd1bd35f8.js +0 -16
  77. /sky/dashboard/out/_next/static/{y8s7LlyyfhMzpzCkxuD2r → VvaUqYDvHOcHZRnvMBmax}/_ssgManifest.js +0 -0
  78. {skypilot_nightly-1.0.0.dev20250916.dist-info → skypilot_nightly-1.0.0.dev20250919.dist-info}/WHEEL +0 -0
  79. {skypilot_nightly-1.0.0.dev20250916.dist-info → skypilot_nightly-1.0.0.dev20250919.dist-info}/entry_points.txt +0 -0
  80. {skypilot_nightly-1.0.0.dev20250916.dist-info → skypilot_nightly-1.0.0.dev20250919.dist-info}/licenses/LICENSE +0 -0
  81. {skypilot_nightly-1.0.0.dev20250916.dist-info → skypilot_nightly-1.0.0.dev20250919.dist-info}/top_level.txt +0 -0
sky/server/metrics.py CHANGED
@@ -1,11 +1,11 @@
1
1
  """Instrumentation for the API server."""
2
2
 
3
- import contextlib
4
- import functools
3
+ import asyncio
5
4
  import multiprocessing
6
5
  import os
7
6
  import threading
8
7
  import time
8
+ from typing import List
9
9
 
10
10
  import fastapi
11
11
  from prometheus_client import generate_latest
@@ -15,112 +15,12 @@ import psutil
15
15
  import starlette.middleware.base
16
16
  import uvicorn
17
17
 
18
+ from sky import core
18
19
  from sky import sky_logging
19
- from sky.skylet import constants
20
-
21
- # Whether the metrics are enabled, cannot be changed at runtime.
22
- METRICS_ENABLED = os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED,
23
- 'false').lower() == 'true'
24
-
25
- _KB = 2**10
26
- _MB = 2**20
27
- _MEM_BUCKETS = [
28
- _KB,
29
- 256 * _KB,
30
- 512 * _KB,
31
- _MB,
32
- 2 * _MB,
33
- 4 * _MB,
34
- 8 * _MB,
35
- 16 * _MB,
36
- 32 * _MB,
37
- 64 * _MB,
38
- 128 * _MB,
39
- 256 * _MB,
40
- float('inf'),
41
- ]
20
+ from sky.metrics import utils as metrics_utils
42
21
 
43
22
  logger = sky_logging.init_logger(__name__)
44
23
 
45
- # Total number of API server requests, grouped by path, method, and status.
46
- SKY_APISERVER_REQUESTS_TOTAL = prom.Counter(
47
- 'sky_apiserver_requests_total',
48
- 'Total number of API server requests',
49
- ['path', 'method', 'status'],
50
- )
51
-
52
- # Time spent processing API server requests, grouped by path, method, and
53
- # status.
54
- SKY_APISERVER_REQUEST_DURATION_SECONDS = prom.Histogram(
55
- 'sky_apiserver_request_duration_seconds',
56
- 'Time spent processing API server requests',
57
- ['path', 'method', 'status'],
58
- buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
59
- 60.0, 120.0, float('inf')),
60
- )
61
-
62
- # Time spent processing a piece of code, refer to time_it().
63
- SKY_APISERVER_CODE_DURATION_SECONDS = prom.Histogram(
64
- 'sky_apiserver_code_duration_seconds',
65
- 'Time spent processing code',
66
- ['name', 'group'],
67
- buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
68
- 60.0, 120.0, float('inf')),
69
- )
70
-
71
- SKY_APISERVER_EVENT_LOOP_LAG_SECONDS = prom.Histogram(
72
- 'sky_apiserver_event_loop_lag_seconds',
73
- 'Scheduling delay of the server event loop',
74
- ['pid'],
75
- buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 20.0,
76
- 60.0, float('inf')),
77
- )
78
-
79
- SKY_APISERVER_WEBSOCKET_CONNECTIONS = prom.Gauge(
80
- 'sky_apiserver_websocket_connections',
81
- 'Number of websocket connections',
82
- ['pid'],
83
- multiprocess_mode='livesum',
84
- )
85
-
86
- SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL = prom.Counter(
87
- 'sky_apiserver_websocket_closed_total',
88
- 'Number of websocket closed',
89
- ['pid', 'reason'],
90
- )
91
-
92
- # The number of execution starts in each worker process, we do not record
93
- # histogram here as the duration has been measured in
94
- # SKY_APISERVER_CODE_DURATION_SECONDS without the worker label (process id).
95
- # Recording histogram WITH worker label will cause high cardinality.
96
- SKY_APISERVER_PROCESS_EXECUTION_START_TOTAL = prom.Counter(
97
- 'sky_apiserver_process_execution_start_total',
98
- 'Total number of execution starts in each worker process',
99
- ['request', 'pid'],
100
- )
101
-
102
- SKY_APISERVER_PROCESS_PEAK_RSS = prom.Gauge(
103
- 'sky_apiserver_process_peak_rss',
104
- 'Peak RSS we saw in each process in last 30 seconds',
105
- ['pid', 'type'],
106
- )
107
-
108
- SKY_APISERVER_PROCESS_CPU_TOTAL = prom.Gauge(
109
- 'sky_apiserver_process_cpu_total',
110
- 'Total CPU times a worker process has been running',
111
- ['pid', 'type', 'mode'],
112
- )
113
-
114
- SKY_APISERVER_REQUEST_MEMORY_USAGE_BYTES = prom.Histogram(
115
- 'sky_apiserver_request_memory_usage_bytes',
116
- 'Peak memory usage of requests', ['name'],
117
- buckets=_MEM_BUCKETS)
118
-
119
- SKY_APISERVER_REQUEST_RSS_INCR_BYTES = prom.Histogram(
120
- 'sky_apiserver_request_rss_incr_bytes',
121
- 'RSS increment after requests', ['name'],
122
- buckets=_MEM_BUCKETS)
123
-
124
24
  metrics_app = fastapi.FastAPI()
125
25
 
126
26
 
@@ -139,6 +39,42 @@ async def metrics() -> fastapi.Response:
139
39
  headers={'Cache-Control': 'no-cache'})
140
40
 
141
41
 
42
+ @metrics_app.get('/gpu-metrics')
43
+ async def gpu_metrics() -> fastapi.Response:
44
+ """Gets the GPU metrics from multiple external k8s clusters"""
45
+ contexts = core.get_all_contexts()
46
+ all_metrics: List[str] = []
47
+ successful_contexts = 0
48
+
49
+ tasks = [
50
+ asyncio.create_task(metrics_utils.get_metrics_for_context(context))
51
+ for context in contexts
52
+ if context != 'in-cluster'
53
+ ]
54
+
55
+ results = await asyncio.gather(*tasks, return_exceptions=True)
56
+
57
+ for i, result in enumerate(results):
58
+ if isinstance(result, Exception):
59
+ logger.error(
60
+ f'Failed to get metrics for context {contexts[i]}: {result}')
61
+ elif isinstance(result, BaseException):
62
+ # Avoid changing behavior for non-Exception BaseExceptions
63
+ # like KeyboardInterrupt/SystemExit: re-raise them.
64
+ raise result
65
+ else:
66
+ metrics_text = result
67
+ all_metrics.append(metrics_text)
68
+ successful_contexts += 1
69
+
70
+ combined_metrics = '\n\n'.join(all_metrics)
71
+
72
+ # Return as plain text for Prometheus compatibility
73
+ return fastapi.Response(
74
+ content=combined_metrics,
75
+ media_type='text/plain; version=0.0.4; charset=utf-8')
76
+
77
+
142
78
  def build_metrics_server(host: str, port: int) -> uvicorn.Server:
143
79
  metrics_config = uvicorn.Config(
144
80
  'sky.server.metrics:metrics_app',
@@ -182,61 +118,17 @@ class PrometheusMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
182
118
  status_code_group = '5xx'
183
119
  raise
184
120
  finally:
185
- SKY_APISERVER_REQUESTS_TOTAL.labels(path=path,
186
- method=method,
187
- status=status_code_group).inc()
121
+ metrics_utils.SKY_APISERVER_REQUESTS_TOTAL.labels(
122
+ path=path, method=method, status=status_code_group).inc()
188
123
  if not streaming:
189
124
  duration = time.time() - start_time
190
- SKY_APISERVER_REQUEST_DURATION_SECONDS.labels(
125
+ metrics_utils.SKY_APISERVER_REQUEST_DURATION_SECONDS.labels(
191
126
  path=path, method=method,
192
127
  status=status_code_group).observe(duration)
193
128
 
194
129
  return response
195
130
 
196
131
 
197
- @contextlib.contextmanager
198
- def time_it(name: str, group: str = 'default'):
199
- """Context manager to measure and record code execution duration."""
200
- if not METRICS_ENABLED:
201
- yield
202
- else:
203
- start_time = time.time()
204
- try:
205
- yield
206
- finally:
207
- duration = time.time() - start_time
208
- SKY_APISERVER_CODE_DURATION_SECONDS.labels(
209
- name=name, group=group).observe(duration)
210
-
211
-
212
- def time_me(func):
213
- """Measure the duration of decorated function."""
214
-
215
- @functools.wraps(func)
216
- def wrapper(*args, **kwargs):
217
- if not METRICS_ENABLED:
218
- return func(*args, **kwargs)
219
- name = f'{func.__module__}/{func.__name__}'
220
- with time_it(name, group='function'):
221
- return func(*args, **kwargs)
222
-
223
- return wrapper
224
-
225
-
226
- def time_me_async(func):
227
- """Measure the duration of decorated async function."""
228
-
229
- @functools.wraps(func)
230
- async def async_wrapper(*args, **kwargs):
231
- if not METRICS_ENABLED:
232
- return await func(*args, **kwargs)
233
- name = f'{func.__module__}/{func.__name__}'
234
- with time_it(name, group='function'):
235
- return await func(*args, **kwargs)
236
-
237
- return async_wrapper
238
-
239
-
240
132
  peak_rss_bytes = 0
241
133
 
242
134
 
@@ -252,13 +144,15 @@ def process_monitor(process_type: str, stop: threading.Event):
252
144
  last_bucket_end = time.time()
253
145
  bucket_peak = 0
254
146
  peak_rss_bytes = max(bucket_peak, proc.memory_info().rss)
255
- SKY_APISERVER_PROCESS_PEAK_RSS.labels(
147
+ metrics_utils.SKY_APISERVER_PROCESS_PEAK_RSS.labels(
256
148
  pid=pid, type=process_type).set(peak_rss_bytes)
257
149
  ctimes = proc.cpu_times()
258
- SKY_APISERVER_PROCESS_CPU_TOTAL.labels(pid=pid,
259
- type=process_type,
260
- mode='user').set(ctimes.user)
261
- SKY_APISERVER_PROCESS_CPU_TOTAL.labels(pid=pid,
262
- type=process_type,
263
- mode='system').set(ctimes.system)
150
+ metrics_utils.SKY_APISERVER_PROCESS_CPU_TOTAL.labels(pid=pid,
151
+ type=process_type,
152
+ mode='user').set(
153
+ ctimes.user)
154
+ metrics_utils.SKY_APISERVER_PROCESS_CPU_TOTAL.labels(pid=pid,
155
+ type=process_type,
156
+ mode='system').set(
157
+ ctimes.system)
264
158
  time.sleep(1)
@@ -39,6 +39,7 @@ from sky import global_user_state
39
39
  from sky import models
40
40
  from sky import sky_logging
41
41
  from sky import skypilot_config
42
+ from sky.metrics import utils as metrics_utils
42
43
  from sky.server import common as server_common
43
44
  from sky.server import config as server_config
44
45
  from sky.server import constants as server_constants
@@ -422,10 +423,10 @@ def _request_execution_wrapper(request_id: str,
422
423
  config = skypilot_config.to_dict()
423
424
  logger.debug(f'request config: \n'
424
425
  f'{yaml_utils.dump_yaml_str(dict(config))}')
425
- metrics_lib.SKY_APISERVER_PROCESS_EXECUTION_START_TOTAL.labels(
426
- request=request_name, pid=pid).inc()
427
- with metrics_lib.time_it(name=request_name,
428
- group='request_execution'):
426
+ (metrics_utils.SKY_APISERVER_PROCESS_EXECUTION_START_TOTAL.
427
+ labels(request=request_name, pid=pid).inc())
428
+ with metrics_utils.time_it(name=request_name,
429
+ group='request_execution'):
429
430
  return_value = func(**request_body.to_kwargs())
430
431
  f.flush()
431
432
  except KeyboardInterrupt:
@@ -465,8 +466,11 @@ def _request_execution_wrapper(request_id: str,
465
466
  # Capture the peak RSS before GC.
466
467
  peak_rss = max(proc.memory_info().rss,
467
468
  metrics_lib.peak_rss_bytes)
468
- with metrics_lib.time_it(name='release_memory',
469
- group='internal'):
469
+ # Clear request level cache to release all memory used by
470
+ # the request.
471
+ annotations.clear_request_level_cache()
472
+ with metrics_utils.time_it(name='release_memory',
473
+ group='internal'):
470
474
  common_utils.release_memory()
471
475
  _record_memory_metrics(request_name, proc, rss_begin, peak_rss)
472
476
  except Exception as e: # pylint: disable=broad-except
@@ -490,11 +494,11 @@ def _record_memory_metrics(request_name: str, proc: psutil.Process,
490
494
  rss_end = proc.memory_info().rss
491
495
 
492
496
  # Answer "how much RSS this request contributed?"
493
- metrics_lib.SKY_APISERVER_REQUEST_RSS_INCR_BYTES.labels(
497
+ metrics_utils.SKY_APISERVER_REQUEST_RSS_INCR_BYTES.labels(
494
498
  name=request_name).observe(max(rss_end - rss_begin, 0))
495
499
  # Estimate the memory usage by the request by capturing the
496
500
  # peak memory delta during the request execution.
497
- metrics_lib.SKY_APISERVER_REQUEST_MEMORY_USAGE_BYTES.labels(
501
+ metrics_utils.SKY_APISERVER_REQUEST_MEMORY_USAGE_BYTES.labels(
498
502
  name=request_name).observe(max(peak_rss - rss_begin, 0))
499
503
 
500
504
 
@@ -792,6 +792,12 @@ class GetConfigBody(RequestBody):
792
792
  class CostReportBody(RequestBody):
793
793
  """The request body for the cost report endpoint."""
794
794
  days: Optional[int] = 30
795
+ # we use hashes instead of names to avoid the case where
796
+ # the name is not unique
797
+ cluster_hashes: Optional[List[str]] = None
798
+ # Only return fields that are needed for the dashboard
799
+ # summary page
800
+ dashboard_summary_response: bool = False
795
801
 
796
802
 
797
803
  class RequestPayload(BasePayload):
@@ -25,10 +25,10 @@ from sky import exceptions
25
25
  from sky import global_user_state
26
26
  from sky import sky_logging
27
27
  from sky import skypilot_config
28
+ from sky.metrics import utils as metrics_lib
28
29
  from sky.server import common as server_common
29
30
  from sky.server import constants as server_constants
30
31
  from sky.server import daemons
31
- from sky.server import metrics as metrics_lib
32
32
  from sky.server.requests import payloads
33
33
  from sky.server.requests.serializers import decoders
34
34
  from sky.server.requests.serializers import encoders
@@ -185,8 +185,9 @@ def encode_cost_report(
185
185
  for cluster_report in cost_report:
186
186
  if cluster_report['status'] is not None:
187
187
  cluster_report['status'] = cluster_report['status'].value
188
- cluster_report['resources'] = pickle_and_encode(
189
- cluster_report['resources'])
188
+ if 'resources' in cluster_report:
189
+ cluster_report['resources'] = pickle_and_encode(
190
+ cluster_report['resources'])
190
191
  return cost_report
191
192
 
192
193
 
sky/server/server.py CHANGED
@@ -437,7 +437,7 @@ async def loop_lag_monitor(loop: asyncio.AbstractEventLoop,
437
437
  if lag_threshold is not None and lag > lag_threshold:
438
438
  logger.warning(f'Event loop lag {lag} seconds exceeds threshold '
439
439
  f'{lag_threshold} seconds.')
440
- metrics.SKY_APISERVER_EVENT_LOOP_LAG_SECONDS.labels(
440
+ metrics_utils.SKY_APISERVER_EVENT_LOOP_LAG_SECONDS.labels(
441
441
  pid=pid).observe(lag)
442
442
  target = now + interval
443
443
  loop.call_at(target, tick)
@@ -470,7 +470,7 @@ async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-nam
470
470
  # can safely ignore the error if the task is already scheduled.
471
471
  logger.debug(f'Request {event.id} already exists.')
472
472
  asyncio.create_task(cleanup_upload_ids())
473
- if metrics.METRICS_ENABLED:
473
+ if metrics_utils.METRICS_ENABLED:
474
474
  # Start monitoring the event loop lag in each server worker
475
475
  # event loop (process).
476
476
  asyncio.create_task(loop_lag_monitor(asyncio.get_event_loop()))
@@ -1743,7 +1743,7 @@ async def kubernetes_pod_ssh_proxy(websocket: fastapi.WebSocket,
1743
1743
  return
1744
1744
 
1745
1745
  logger.info(f'Starting port-forward to local port: {local_port}')
1746
- conn_gauge = metrics.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
1746
+ conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
1747
1747
  pid=os.getpid())
1748
1748
  ssh_failed = False
1749
1749
  websocket_closed = False
@@ -1807,14 +1807,14 @@ async def kubernetes_pod_ssh_proxy(websocket: fastapi.WebSocket,
1807
1807
  'ssh websocket connection was closed. Remaining '
1808
1808
  f'output: {str(stdout)}')
1809
1809
  reason = 'KubectlPortForwardExit'
1810
- metrics.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1810
+ metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1811
1811
  pid=os.getpid(), reason='KubectlPortForwardExit').inc()
1812
1812
  else:
1813
1813
  if ssh_failed:
1814
1814
  reason = 'SSHToPodDisconnected'
1815
1815
  else:
1816
1816
  reason = 'ClientClosed'
1817
- metrics.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1817
+ metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1818
1818
  pid=os.getpid(), reason=reason).inc()
1819
1819
 
1820
1820
 
@@ -1831,42 +1831,6 @@ async def all_contexts(request: fastapi.Request) -> None:
1831
1831
  )
1832
1832
 
1833
1833
 
1834
- @app.get('/gpu-metrics')
1835
- async def gpu_metrics() -> fastapi.Response:
1836
- """Gets the GPU metrics from multiple external k8s clusters"""
1837
- contexts = core.get_all_contexts()
1838
- all_metrics: List[str] = []
1839
- successful_contexts = 0
1840
-
1841
- tasks = [
1842
- asyncio.create_task(metrics_utils.get_metrics_for_context(context))
1843
- for context in contexts
1844
- if context != 'in-cluster'
1845
- ]
1846
-
1847
- results = await asyncio.gather(*tasks, return_exceptions=True)
1848
-
1849
- for i, result in enumerate(results):
1850
- if isinstance(result, Exception):
1851
- logger.error(
1852
- f'Failed to get metrics for context {contexts[i]}: {result}')
1853
- elif isinstance(result, BaseException):
1854
- # Avoid changing behavior for non-Exception BaseExceptions
1855
- # like KeyboardInterrupt/SystemExit: re-raise them.
1856
- raise result
1857
- else:
1858
- metrics_text = result
1859
- all_metrics.append(metrics_text)
1860
- successful_contexts += 1
1861
-
1862
- combined_metrics = '\n\n'.join(all_metrics)
1863
-
1864
- # Return as plain text for Prometheus compatibility
1865
- return fastapi.Response(
1866
- content=combined_metrics,
1867
- media_type='text/plain; version=0.0.4; charset=utf-8')
1868
-
1869
-
1870
1834
  # === Internal APIs ===
1871
1835
  @app.get('/api/completion/cluster_name')
1872
1836
  async def complete_cluster_name(incomplete: str,) -> List[str]:
@@ -189,6 +189,7 @@ extras_require: Dict[str, List[str]] = {
189
189
  'fluidstack': [], # No dependencies needed for fluidstack
190
190
  'cudo': ['cudo-compute>=0.1.10'],
191
191
  'paperspace': [], # No dependencies needed for paperspace
192
+ 'primeintellect': [], # No dependencies needed for primeintellect
192
193
  'do': ['pydo>=0.3.0', 'azure-core>=1.24.0', 'azure-common'],
193
194
  'vast': ['vastai-sdk>=0.1.12'],
194
195
  'vsphere': [
sky/skylet/constants.py CHANGED
@@ -29,6 +29,7 @@ SKY_REMOTE_RAY_PORT_FILE = '~/.sky/ray_port.json'
29
29
  SKY_REMOTE_RAY_TEMPDIR = '/tmp/ray_skypilot'
30
30
  SKY_REMOTE_RAY_VERSION = '2.9.3'
31
31
 
32
+ SKY_UNSET_PYTHONPATH = 'env -u PYTHONPATH'
32
33
  # We store the absolute path of the python executable (/opt/conda/bin/python3)
33
34
  # in this file, so that any future internal commands that need to use python
34
35
  # can use this path. This is useful for the case where the user has a custom
@@ -40,7 +41,7 @@ SKY_GET_PYTHON_PATH_CMD = (f'[ -s {SKY_PYTHON_PATH_FILE} ] && '
40
41
  f'cat {SKY_PYTHON_PATH_FILE} 2> /dev/null || '
41
42
  'which python3')
42
43
  # Python executable, e.g., /opt/conda/bin/python3
43
- SKY_PYTHON_CMD = f'$({SKY_GET_PYTHON_PATH_CMD})'
44
+ SKY_PYTHON_CMD = f'{SKY_UNSET_PYTHONPATH} $({SKY_GET_PYTHON_PATH_CMD})'
44
45
  # Prefer SKY_UV_PIP_CMD, which is faster.
45
46
  # TODO(cooperc): remove remaining usage (GCP TPU setup).
46
47
  SKY_PIP_CMD = f'{SKY_PYTHON_CMD} -m pip'
@@ -56,13 +57,16 @@ SKY_REMOTE_PYTHON_ENV: str = f'~/{SKY_REMOTE_PYTHON_ENV_NAME}'
56
57
  ACTIVATE_SKY_REMOTE_PYTHON_ENV = f'source {SKY_REMOTE_PYTHON_ENV}/bin/activate'
57
58
  # uv is used for venv and pip, much faster than python implementations.
58
59
  SKY_UV_INSTALL_DIR = '"$HOME/.local/bin"'
59
- SKY_UV_CMD = f'UV_SYSTEM_PYTHON=false {SKY_UV_INSTALL_DIR}/uv'
60
+ SKY_UV_CMD = ('UV_SYSTEM_PYTHON=false '
61
+ f'{SKY_UNSET_PYTHONPATH} {SKY_UV_INSTALL_DIR}/uv')
60
62
  # This won't reinstall uv if it's already installed, so it's safe to re-run.
61
63
  SKY_UV_INSTALL_CMD = (f'{SKY_UV_CMD} -V >/dev/null 2>&1 || '
62
64
  'curl -LsSf https://astral.sh/uv/install.sh '
63
65
  f'| UV_INSTALL_DIR={SKY_UV_INSTALL_DIR} sh')
64
66
  SKY_UV_PIP_CMD: str = (f'VIRTUAL_ENV={SKY_REMOTE_PYTHON_ENV} {SKY_UV_CMD} pip')
65
- SKY_UV_RUN_CMD: str = (f'VIRTUAL_ENV={SKY_REMOTE_PYTHON_ENV} {SKY_UV_CMD} run')
67
+ SKY_UV_RUN_CMD: str = (
68
+ f'VIRTUAL_ENV={SKY_REMOTE_PYTHON_ENV} {SKY_UV_CMD} run --active '
69
+ '--no-project --no-config')
66
70
  # Deleting the SKY_REMOTE_PYTHON_ENV_NAME from the PATH and unsetting relevant
67
71
  # VIRTUAL_ENV envvars to deactivate the environment. `deactivate` command does
68
72
  # not work when conda is used.
@@ -153,7 +157,7 @@ CONDA_INSTALLATION_COMMANDS = (
153
157
  # because for some images, conda is already installed, but not initialized.
154
158
  # In this case, we need to initialize conda and set auto_activate_base to
155
159
  # true.
156
- '{ bash Miniconda3-Linux.sh -b; '
160
+ '{ bash Miniconda3-Linux.sh -b || true; '
157
161
  'eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && '
158
162
  # Caller should replace {conda_auto_activate} with either true or false.
159
163
  'conda config --set auto_activate_base {conda_auto_activate} && '
@@ -456,7 +460,8 @@ CATALOG_SCHEMA_VERSION = 'v8'
456
460
  CATALOG_DIR = '~/.sky/catalogs'
457
461
  ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
458
462
  'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
459
- 'paperspace', 'do', 'nebius', 'ssh', 'hyperbolic', 'seeweb')
463
+ 'paperspace', 'primeintellect', 'do', 'nebius', 'ssh',
464
+ 'hyperbolic', 'seeweb')
460
465
  # END constants used for service catalog.
461
466
 
462
467
  # The user ID of the SkyPilot system.
sky/skylet/job_lib.py CHANGED
@@ -559,21 +559,20 @@ def get_jobs_info(user_hash: Optional[str] = None,
559
559
  jobs_info = []
560
560
  for job in jobs:
561
561
  jobs_info.append(
562
- jobsv1_pb2.JobInfo(
563
- job_id=job['job_id'],
564
- job_name=job['job_name'],
565
- username=job['username'],
566
- submitted_at=job['submitted_at'],
567
- status=job['status'].to_protobuf(),
568
- run_timestamp=job['run_timestamp'],
569
- start_at=job['start_at']
570
- if job['start_at'] is not None else -1.0,
571
- end_at=job['end_at'] if job['end_at'] is not None else 0.0,
572
- resources=job['resources'] or '',
573
- pid=job['pid'],
574
- log_path=os.path.join(constants.SKY_LOGS_DIRECTORY,
575
- job['run_timestamp']),
576
- metadata=json.dumps(job['metadata'])))
562
+ jobsv1_pb2.JobInfo(job_id=job['job_id'],
563
+ job_name=job['job_name'],
564
+ username=job['username'],
565
+ submitted_at=job['submitted_at'],
566
+ status=job['status'].to_protobuf(),
567
+ run_timestamp=job['run_timestamp'],
568
+ start_at=job['start_at'],
569
+ end_at=job['end_at'],
570
+ resources=job['resources'],
571
+ pid=job['pid'],
572
+ log_path=os.path.join(
573
+ constants.SKY_LOGS_DIRECTORY,
574
+ job['run_timestamp']),
575
+ metadata=json.dumps(job['metadata'])))
577
576
  return jobs_info
578
577
 
579
578
 
sky/skylet/services.py CHANGED
@@ -10,7 +10,11 @@ from sky.schemas.generated import autostopv1_pb2
10
10
  from sky.schemas.generated import autostopv1_pb2_grpc
11
11
  from sky.schemas.generated import jobsv1_pb2
12
12
  from sky.schemas.generated import jobsv1_pb2_grpc
13
+ from sky.schemas.generated import servev1_pb2
14
+ from sky.schemas.generated import servev1_pb2_grpc
15
+ from sky.serve import serve_rpc_utils
13
16
  from sky.serve import serve_state
17
+ from sky.serve import serve_utils
14
18
  from sky.skylet import autostop_lib
15
19
  from sky.skylet import constants
16
20
  from sky.skylet import job_lib
@@ -52,6 +56,100 @@ class AutostopServiceImpl(autostopv1_pb2_grpc.AutostopServiceServicer):
52
56
  context.abort(grpc.StatusCode.INTERNAL, str(e))
53
57
 
54
58
 
59
+ class ServeServiceImpl(servev1_pb2_grpc.ServeServiceServicer):
60
+ """Implementation of the ServeService gRPC service."""
61
+
62
+ # NOTE (kyuds): this grpc service will run cluster-side,
63
+ # thus guaranteeing that SERVE_VERSION is above 5.
64
+ # Therefore, we removed some SERVE_VERSION checks
65
+ # present in the original codegen.
66
+
67
+ def GetServiceStatus( # type: ignore[return]
68
+ self, request: servev1_pb2.GetServiceStatusRequest,
69
+ context: grpc.ServicerContext
70
+ ) -> servev1_pb2.GetServiceStatusResponse:
71
+ """Gets serve status."""
72
+ try:
73
+ service_names, pool = (
74
+ serve_rpc_utils.GetServiceStatusRequestConverter.from_proto(request)) # pylint: disable=line-too-long
75
+ statuses = serve_utils.get_service_status_pickled(
76
+ service_names, pool)
77
+ return serve_rpc_utils.GetServiceStatusResponseConverter.to_proto(
78
+ statuses)
79
+ except Exception as e: # pylint: disable=broad-except
80
+ context.abort(grpc.StatusCode.INTERNAL, str(e))
81
+
82
+ def AddVersion( # type: ignore[return]
83
+ self, request: servev1_pb2.AddVersionRequest,
84
+ context: grpc.ServicerContext) -> servev1_pb2.AddVersionResponse:
85
+ """Adds serve version"""
86
+ try:
87
+ service_name = request.service_name
88
+ version = serve_state.add_version(service_name)
89
+ return servev1_pb2.AddVersionResponse(version=version)
90
+ except Exception as e: # pylint: disable=broad-except
91
+ context.abort(grpc.StatusCode.INTERNAL, str(e))
92
+
93
+ def TerminateServices( # type: ignore[return]
94
+ self, request: servev1_pb2.TerminateServicesRequest,
95
+ context: grpc.ServicerContext
96
+ ) -> servev1_pb2.TerminateServicesResponse:
97
+ """Terminates serve"""
98
+ try:
99
+ service_names, purge, pool = (
100
+ serve_rpc_utils.TerminateServicesRequestConverter.from_proto(request)) # pylint: disable=line-too-long
101
+ message = serve_utils.terminate_services(service_names, purge, pool)
102
+ return servev1_pb2.TerminateServicesResponse(message=message)
103
+ except Exception as e: # pylint: disable=broad-except
104
+ context.abort(grpc.StatusCode.INTERNAL, str(e))
105
+
106
+ def TerminateReplica( # type: ignore[return]
107
+ self, request: servev1_pb2.TerminateReplicaRequest,
108
+ context: grpc.ServicerContext
109
+ ) -> servev1_pb2.TerminateReplicaResponse:
110
+ """Terminate replica"""
111
+ try:
112
+ service_name = request.service_name
113
+ replica_id = request.replica_id
114
+ purge = request.purge
115
+ message = serve_utils.terminate_replica(service_name, replica_id,
116
+ purge)
117
+ return servev1_pb2.TerminateReplicaResponse(message=message)
118
+ except Exception as e: # pylint: disable=broad-except
119
+ context.abort(grpc.StatusCode.INTERNAL, str(e))
120
+
121
+ def WaitServiceRegistration( # type: ignore[return]
122
+ self, request: servev1_pb2.WaitServiceRegistrationRequest,
123
+ context: grpc.ServicerContext
124
+ ) -> servev1_pb2.WaitServiceRegistrationResponse:
125
+ """Wait for service to be registered"""
126
+ try:
127
+ service_name = request.service_name
128
+ job_id = request.job_id
129
+ pool = request.pool
130
+ encoded = serve_utils.wait_service_registration(
131
+ service_name, job_id, pool)
132
+ lb_port = serve_utils.load_service_initialization_result(encoded)
133
+ return servev1_pb2.WaitServiceRegistrationResponse(lb_port=lb_port)
134
+ except Exception as e: # pylint: disable=broad-except
135
+ context.abort(grpc.StatusCode.INTERNAL, str(e))
136
+
137
+ def UpdateService( # type: ignore[return]
138
+ self, request: servev1_pb2.UpdateServiceRequest,
139
+ context: grpc.ServicerContext) -> servev1_pb2.UpdateServiceResponse:
140
+ """Update service"""
141
+ try:
142
+ service_name = request.service_name
143
+ version = request.version
144
+ mode = request.mode
145
+ pool = request.pool
146
+ serve_utils.update_service_encoded(service_name, version, mode,
147
+ pool)
148
+ return servev1_pb2.UpdateServiceResponse()
149
+ except Exception as e: # pylint: disable=broad-except
150
+ context.abort(grpc.StatusCode.INTERNAL, str(e))
151
+
152
+
55
153
  class JobsServiceImpl(jobsv1_pb2_grpc.JobsServiceServicer):
56
154
  """Implementation of the JobsService gRPC service."""
57
155
 
sky/skylet/skylet.py CHANGED
@@ -10,6 +10,7 @@ import sky
10
10
  from sky import sky_logging
11
11
  from sky.schemas.generated import autostopv1_pb2_grpc
12
12
  from sky.schemas.generated import jobsv1_pb2_grpc
13
+ from sky.schemas.generated import servev1_pb2_grpc
13
14
  from sky.skylet import constants
14
15
  from sky.skylet import events
15
16
  from sky.skylet import services
@@ -50,9 +51,10 @@ def start_grpc_server(port: int = constants.SKYLET_GRPC_PORT) -> grpc.Server:
50
51
 
51
52
  autostopv1_pb2_grpc.add_AutostopServiceServicer_to_server(
52
53
  services.AutostopServiceImpl(), server)
53
-
54
54
  jobsv1_pb2_grpc.add_JobsServiceServicer_to_server(
55
55
  services.JobsServiceImpl(), server)
56
+ servev1_pb2_grpc.add_ServeServiceServicer_to_server(
57
+ services.ServeServiceImpl(), server)
56
58
 
57
59
  listen_addr = f'127.0.0.1:{port}'
58
60
  server.add_insecure_port(listen_addr)