skypilot-nightly 1.0.0.dev20251013__py3-none-any.whl → 1.0.0.dev20251014__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.

Files changed (52)
  1. sky/__init__.py +2 -2
  2. sky/authentication.py +9 -2
  3. sky/backends/backend_utils.py +33 -25
  4. sky/backends/cloud_vm_ray_backend.py +3 -5
  5. sky/catalog/kubernetes_catalog.py +19 -25
  6. sky/client/cli/command.py +53 -19
  7. sky/client/sdk.py +13 -1
  8. sky/dashboard/out/404.html +1 -1
  9. sky/dashboard/out/_next/static/chunks/{webpack-ac3a34c8f9fef041.js → webpack-66f23594d38c7f16.js} +1 -1
  10. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  11. sky/dashboard/out/clusters/[cluster].html +1 -1
  12. sky/dashboard/out/clusters.html +1 -1
  13. sky/dashboard/out/config.html +1 -1
  14. sky/dashboard/out/index.html +1 -1
  15. sky/dashboard/out/infra/[context].html +1 -1
  16. sky/dashboard/out/infra.html +1 -1
  17. sky/dashboard/out/jobs/[job].html +1 -1
  18. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  19. sky/dashboard/out/jobs.html +1 -1
  20. sky/dashboard/out/users.html +1 -1
  21. sky/dashboard/out/volumes.html +1 -1
  22. sky/dashboard/out/workspace/new.html +1 -1
  23. sky/dashboard/out/workspaces/[name].html +1 -1
  24. sky/dashboard/out/workspaces.html +1 -1
  25. sky/jobs/controller.py +122 -145
  26. sky/jobs/recovery_strategy.py +59 -82
  27. sky/jobs/scheduler.py +5 -5
  28. sky/jobs/state.py +65 -21
  29. sky/jobs/utils.py +58 -22
  30. sky/metrics/utils.py +27 -6
  31. sky/provision/kubernetes/utils.py +44 -39
  32. sky/server/common.py +4 -2
  33. sky/server/requests/executor.py +3 -1
  34. sky/server/server.py +5 -0
  35. sky/sky_logging.py +0 -2
  36. sky/skylet/constants.py +22 -5
  37. sky/skylet/log_lib.py +0 -1
  38. sky/skylet/log_lib.pyi +1 -1
  39. sky/utils/common.py +2 -0
  40. sky/utils/context.py +57 -51
  41. sky/utils/context_utils.py +2 -2
  42. sky/utils/controller_utils.py +35 -8
  43. sky/utils/locks.py +20 -5
  44. sky/utils/subprocess_utils.py +4 -3
  45. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/METADATA +36 -36
  46. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/RECORD +52 -52
  47. /sky/dashboard/out/_next/static/{MtlDUf-nH1hhcy7xwbCj3 → 9Fek73R28lDp1A5J4N7g7}/_buildManifest.js +0 -0
  48. /sky/dashboard/out/_next/static/{MtlDUf-nH1hhcy7xwbCj3 → 9Fek73R28lDp1A5J4N7g7}/_ssgManifest.js +0 -0
  49. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/WHEEL +0 -0
  50. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/entry_points.txt +0 -0
  51. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/licenses/LICENSE +0 -0
  52. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py CHANGED
@@ -8,7 +8,6 @@ import asyncio
  import collections
  import datetime
  import enum
- import logging
  import os
  import pathlib
  import re
@@ -84,6 +83,7 @@ _LOG_STREAM_CHECK_CONTROLLER_GAP_SECONDS = 5

  _JOB_STATUS_FETCH_MAX_RETRIES = 3
  _JOB_K8S_TRANSIENT_NW_MSG = 'Unable to connect to the server: dial tcp'
+ _JOB_STATUS_FETCH_TIMEOUT_SECONDS = 30

  _JOB_WAITING_STATUS_MESSAGE = ux_utils.spinner_message(
      'Waiting for task to start[/]'
@@ -101,6 +101,13 @@ _JOB_CANCELLED_MESSAGE = (
  # update the state.
  _FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 120

+ # After enabling consolidation mode, we need to restart the API server to get
+ # the jobs refresh deamon and correct number of executors. We use this file to
+ # indicate that the API server has been restarted after enabling consolidation
+ # mode.
+ _JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE = (
+     '~/.sky/.jobs_controller_consolidation_reloaded_signal')
+

  class ManagedJobQueueResultType(enum.Enum):
      """The type of the managed job queue result."""
@@ -117,9 +124,8 @@ class UserSignal(enum.Enum):

  # ====== internal functions ======
  def terminate_cluster(
-         cluster_name: str,
-         max_retry: int = 6,
-         _logger: logging.Logger = logger,  # pylint: disable=invalid-name
+     cluster_name: str,
+     max_retry: int = 6,
  ) -> None:
      """Terminate the cluster."""
      from sky import core  # pylint: disable=import-outside-toplevel
@@ -143,18 +149,18 @@ def terminate_cluster(
              return
          except exceptions.ClusterDoesNotExist:
              # The cluster is already down.
-             _logger.debug(f'The cluster {cluster_name} is already down.')
+             logger.debug(f'The cluster {cluster_name} is already down.')
              return
          except Exception as e:  # pylint: disable=broad-except
              retry_cnt += 1
              if retry_cnt >= max_retry:
                  raise RuntimeError(
                      f'Failed to terminate the cluster {cluster_name}.') from e
-             _logger.error(
+             logger.error(
                  f'Failed to terminate the cluster {cluster_name}. Retrying.'
                  f'Details: {common_utils.format_exception(e)}')
              with ux_utils.enable_traceback():
-                 _logger.error(f' Traceback: {traceback.format_exc()}')
+                 logger.error(f' Traceback: {traceback.format_exc()}')
              time.sleep(backoff.current_backoff())


@@ -202,13 +208,39 @@ def _validate_consolidation_mode_config(
  # API Server. Under the hood, we submit the job monitoring logic as processes
  # directly in the API Server.
  # Use LRU Cache so that the check is only done once.
- @annotations.lru_cache(scope='request', maxsize=1)
- def is_consolidation_mode() -> bool:
+ @annotations.lru_cache(scope='request', maxsize=2)
+ def is_consolidation_mode(on_api_restart: bool = False) -> bool:
      if os.environ.get(constants.OVERRIDE_CONSOLIDATION_MODE) is not None:
          return True

-     consolidation_mode = skypilot_config.get_nested(
+     config_consolidation_mode = skypilot_config.get_nested(
          ('jobs', 'controller', 'consolidation_mode'), default_value=False)
+
+     signal_file = pathlib.Path(
+         _JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE).expanduser()
+
+     restart_signal_file_exists = signal_file.exists()
+     consolidation_mode = (config_consolidation_mode and
+                           restart_signal_file_exists)
+
+     if on_api_restart:
+         if config_consolidation_mode:
+             signal_file.touch()
+     else:
+         if not restart_signal_file_exists:
+             if config_consolidation_mode:
+                 logger.warning(f'{colorama.Fore.YELLOW}Consolidation mode for '
+                                'managed jobs is enabled in the server config, '
+                                'but the API server has not been restarted yet. '
+                                'Please restart the API server to enable it.'
+                                f'{colorama.Style.RESET_ALL}')
+                 return False
+         elif not config_consolidation_mode:
+             # Cleanup the signal file if the consolidation mode is disabled in
+             # the config. This allow the user to disable the consolidation mode
+             # without restarting the API server.
+             signal_file.unlink()
+
      # We should only do this check on API server, as the controller will not
      # have related config and will always seemingly disabled for consolidation
      # mode. Check #6611 for more details.
@@ -269,8 +301,7 @@ def ha_recovery_for_consolidation_mode():

  async def get_job_status(
          backend: 'backends.CloudVmRayBackend', cluster_name: str,
-         job_id: Optional[int],
-         job_logger: logging.Logger) -> Optional['job_lib.JobStatus']:
+         job_id: Optional[int]) -> Optional['job_lib.JobStatus']:
      """Check the status of the job running on a managed job cluster.

      It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
@@ -282,26 +313,28 @@ async def get_job_status(
      if handle is None:
          # This can happen if the cluster was preempted and background status
          # refresh already noticed and cleaned it up.
-         job_logger.info(f'Cluster {cluster_name} not found.')
+         logger.info(f'Cluster {cluster_name} not found.')
          return None
      assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
      job_ids = None if job_id is None else [job_id]
      for i in range(_JOB_STATUS_FETCH_MAX_RETRIES):
          try:
-             job_logger.info('=== Checking the job status... ===')
-             statuses = await context_utils.to_thread(backend.get_job_status,
-                                                      handle,
-                                                      job_ids=job_ids,
-                                                      stream_logs=False)
+             logger.info('=== Checking the job status... ===')
+             statuses = await asyncio.wait_for(
+                 context_utils.to_thread(backend.get_job_status,
+                                         handle,
+                                         job_ids=job_ids,
+                                         stream_logs=False),
+                 timeout=_JOB_STATUS_FETCH_TIMEOUT_SECONDS)
              status = list(statuses.values())[0]
              if status is None:
-                 job_logger.info('No job found.')
+                 logger.info('No job found.')
              else:
-                 job_logger.info(f'Job status: {status}')
-                 job_logger.info('=' * 34)
+                 logger.info(f'Job status: {status}')
+                 logger.info('=' * 34)
              return status
          except (exceptions.CommandError, grpc.RpcError, grpc.FutureTimeoutError,
-                 ValueError, TypeError) as e:
+                 ValueError, TypeError, asyncio.TimeoutError) as e:
              # Note: Each of these exceptions has some additional conditions to
              # limit how we handle it and whether or not we catch it.
              # Retry on k8s transient network errors. This is useful when using
@@ -322,6 +355,9 @@ async def get_job_status(
                  is_transient_error = True
              elif isinstance(e, grpc.FutureTimeoutError):
                  detailed_reason = 'Timeout'
+             elif isinstance(e, asyncio.TimeoutError):
+                 detailed_reason = ('Job status check timed out after '
+                                    f'{_JOB_STATUS_FETCH_TIMEOUT_SECONDS}s')
              # TODO(cooperc): Gracefully handle these exceptions in the backend.
              elif isinstance(e, ValueError):
                  # If the cluster yaml is deleted in the middle of getting the
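The get_job_status change above bounds each status probe to 30 seconds and treats asyncio.TimeoutError as one more retryable error. A minimal standalone sketch of that pattern, using the standard library's asyncio.to_thread in place of SkyPilot's context_utils.to_thread (fetch_status and the constants here are placeholders, not SkyPilot APIs):

    import asyncio
    import time
    from typing import Optional

    _TIMEOUT_SECONDS = 30
    _MAX_RETRIES = 3


    def fetch_status() -> str:
        # Placeholder for a blocking call such as backend.get_job_status().
        time.sleep(1)
        return 'RUNNING'


    async def get_status_with_timeout() -> Optional[str]:
        for attempt in range(_MAX_RETRIES):
            try:
                # Run the blocking call in a worker thread, but give up waiting
                # if it does not return within the timeout.
                return await asyncio.wait_for(asyncio.to_thread(fetch_status),
                                              timeout=_TIMEOUT_SECONDS)
            except asyncio.TimeoutError:
                # Treat the timeout as transient and retry.
                if attempt == _MAX_RETRIES - 1:
                    return None
        return None


    if __name__ == '__main__':
        print(asyncio.run(get_status_with_timeout()))

Note that asyncio.wait_for only unblocks the awaiting coroutine on timeout; the worker thread keeps running until the blocking call itself returns.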
sky/metrics/utils.py CHANGED
@@ -48,8 +48,15 @@ SKY_APISERVER_CODE_DURATION_SECONDS = prom.Histogram(
      'sky_apiserver_code_duration_seconds',
      'Time spent processing code',
      ['name', 'group'],
-     buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
-              60.0, 120.0, float('inf')),
+     buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+              0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+              5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+              50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+              240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+              420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+              600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+              780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+              960.0, 980.0, 1000.0, float('inf')),
  )

  # Total number of API server requests, grouped by path, method, and status.
@@ -65,16 +72,30 @@ SKY_APISERVER_REQUEST_DURATION_SECONDS = prom.Histogram(
      'sky_apiserver_request_duration_seconds',
      'Time spent processing API server requests',
      ['path', 'method', 'status'],
-     buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
-              60.0, 120.0, float('inf')),
+     buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+              0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+              5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+              50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+              240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+              420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+              600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+              780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+              960.0, 980.0, 1000.0, float('inf')),
  )

  SKY_APISERVER_EVENT_LOOP_LAG_SECONDS = prom.Histogram(
      'sky_apiserver_event_loop_lag_seconds',
      'Scheduling delay of the server event loop',
      ['pid'],
-     buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 20.0,
-              60.0, float('inf')),
+     buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+              0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+              5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+              50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+              240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+              420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+              600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+              780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+              960.0, 980.0, 1000.0, float('inf')),
  )

  SKY_APISERVER_WEBSOCKET_CONNECTIONS = prom.Gauge(
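These changes only widen the explicit bucket boundaries of the existing histograms; metric names and labels are unchanged. For reference, a small self-contained example of how a prometheus_client Histogram with custom buckets is declared and observed (the metric name and label below are illustrative, not SkyPilot's):

    import prometheus_client as prom

    REQUEST_SECONDS = prom.Histogram(
        'example_request_duration_seconds',
        'Time spent handling example requests',
        ['path'],
        # Finer buckets at the low end give better latency resolution; the
        # implicit last bucket is always +Inf.
        buckets=(0.001, 0.01, 0.1, 1, 10, 60, float('inf')),
    )

    # Each observation increments every bucket whose upper bound covers the value.
    REQUEST_SECONDS.labels(path='/api/v1/status').observe(0.042)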
sky/provision/kubernetes/utils.py CHANGED
@@ -1299,30 +1299,52 @@ class V1Pod:


  @_retry_on_error(resource_type='pod')
- def get_all_pods_in_kubernetes_cluster(*,
-                                        context: Optional[str] = None
-                                       ) -> List[V1Pod]:
-     """Gets pods in all namespaces in kubernetes cluster indicated by context.
-
-     Used for computing cluster resource usage.
+ def get_allocated_gpu_qty_by_node(
+     *,
+     context: Optional[str] = None,
+ ) -> Dict[str, int]:
+     """Gets allocated GPU quantity by each node by fetching pods in
+     all namespaces in kubernetes cluster indicated by context.
      """
      if context is None:
          context = get_current_kube_config_context_name()
+     non_included_pod_statuses = POD_STATUSES.copy()
+     status_filters = ['Running', 'Pending']
+     if status_filters is not None:
+         non_included_pod_statuses -= set(status_filters)
+     field_selector = ','.join(
+         [f'status.phase!={status}' for status in non_included_pod_statuses])

      # Return raw urllib3.HTTPResponse object so that we can parse the json
      # more efficiently.
      response = kubernetes.core_api(context).list_pod_for_all_namespaces(
-         _request_timeout=kubernetes.API_TIMEOUT, _preload_content=False)
+         _request_timeout=kubernetes.API_TIMEOUT,
+         _preload_content=False,
+         field_selector=field_selector)
      try:
-         pods = [
-             V1Pod.from_dict(item_dict) for item_dict in ijson.items(
-                 response, 'items.item', buf_size=IJSON_BUFFER_SIZE)
-         ]
+         allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
+         for item_dict in ijson.items(response,
+                                      'items.item',
+                                      buf_size=IJSON_BUFFER_SIZE):
+             pod = V1Pod.from_dict(item_dict)
+             if should_exclude_pod_from_gpu_allocation(pod):
+                 logger.debug(
+                     f'Excluding pod {pod.metadata.name} from GPU count '
+                     f'calculations on node {pod.spec.node_name}')
+                 continue
+             # Iterate over all the containers in the pod and sum the
+             # GPU requests
+             pod_allocated_qty = 0
+             for container in pod.spec.containers:
+                 if container.resources.requests:
+                     pod_allocated_qty += get_node_accelerator_count(
+                         context, container.resources.requests)
+             if pod_allocated_qty > 0 and pod.spec.node_name:
+                 allocated_qty_by_node[pod.spec.node_name] += pod_allocated_qty
+         return allocated_qty_by_node
      finally:
          response.release_conn()

-     return pods
-

  def check_instance_fits(context: Optional[str],
@@ -3006,41 +3028,24 @@ def get_kubernetes_node_info(
      label_keys = lf.get_label_keys()

      # Check if all nodes have no accelerators to avoid fetching pods
-     any_node_has_accelerators = False
+     has_accelerator_nodes = False
      for node in nodes:
          accelerator_count = get_node_accelerator_count(context,
                                                         node.status.allocatable)
          if accelerator_count > 0:
-             any_node_has_accelerators = True
+             has_accelerator_nodes = True
              break

-     # Get the pods to get the real-time resource usage
-     pods = None
+     # Get the allocated GPU quantity by each node
      allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
-     if any_node_has_accelerators:
+     error_on_get_allocated_gpu_qty_by_node = False
+     if has_accelerator_nodes:
          try:
-             pods = get_all_pods_in_kubernetes_cluster(context=context)
-             # Pre-compute allocated accelerator count per node
-             for pod in pods:
-                 if pod.status.phase in ['Running', 'Pending']:
-                     # Skip pods that should not count against GPU count
-                     if should_exclude_pod_from_gpu_allocation(pod):
-                         logger.debug(f'Excluding low priority pod '
-                                      f'{pod.metadata.name} from GPU allocation '
-                                      f'calculations')
-                         continue
-                     # Iterate over all the containers in the pod and sum the
-                     # GPU requests
-                     pod_allocated_qty = 0
-                     for container in pod.spec.containers:
-                         if container.resources.requests:
-                             pod_allocated_qty += get_node_accelerator_count(
-                                 context, container.resources.requests)
-                     if pod_allocated_qty > 0:
-                         allocated_qty_by_node[
-                             pod.spec.node_name] += pod_allocated_qty
+             allocated_qty_by_node = get_allocated_gpu_qty_by_node(
+                 context=context)
          except kubernetes.api_exception() as e:
              if e.status == 403:
+                 error_on_get_allocated_gpu_qty_by_node = True
                  pass
              else:
                  raise
@@ -3085,7 +3090,7 @@ def get_kubernetes_node_info(
                  ip_address=node_ip)
              continue

-         if pods is None:
+         if not has_accelerator_nodes or error_on_get_allocated_gpu_qty_by_node:
              accelerators_available = -1
          else:
              allocated_qty = allocated_qty_by_node[node.metadata.name]
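The new get_allocated_gpu_qty_by_node pushes the Running/Pending filter into the API call itself via a field selector, instead of fetching every pod and filtering client-side, and it aggregates while streaming rather than materializing a full pod list. A rough sketch of that server-side filtering with the official kubernetes Python client (the 'nvidia.com/gpu' resource key and counting logic are illustrative; SkyPilot's helpers also handle other accelerator keys):

    import collections

    from kubernetes import client, config


    def allocated_gpus_by_node(excluded_phases=('Succeeded', 'Failed', 'Unknown')):
        config.load_kube_config()
        # Exclude terminal/unknown phases server-side so only Running/Pending
        # pods are returned by the API server.
        field_selector = ','.join(
            f'status.phase!={phase}' for phase in excluded_phases)
        pods = client.CoreV1Api().list_pod_for_all_namespaces(
            field_selector=field_selector)
        allocated = collections.defaultdict(int)
        for pod in pods.items:
            qty = 0
            for container in pod.spec.containers:
                requests = container.resources.requests or {}
                qty += int(requests.get('nvidia.com/gpu', 0))
            if qty and pod.spec.node_name:
                allocated[pod.spec.node_name] += qty
        return dict(allocated)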
sky/server/common.py CHANGED
@@ -554,8 +554,8 @@ def _start_api_server(deploy: bool = False,
          # pylint: disable=import-outside-toplevel
          import sky.jobs.utils as job_utils
          max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
-                       if job_utils.is_consolidation_mode() else
-                       server_constants.MIN_AVAIL_MEM_GB)
+                       if job_utils.is_consolidation_mode(on_api_restart=True)
+                       else server_constants.MIN_AVAIL_MEM_GB)
          if avail_mem_size_gb <= max_memory:
              logger.warning(
                  f'{colorama.Fore.YELLOW}Your SkyPilot API server machine only '
@@ -571,6 +571,8 @@ def _start_api_server(deploy: bool = False,
          args += [f'--host={host}']
          if metrics_port is not None:
              args += [f'--metrics-port={metrics_port}']
+         # Use this argument to disable the internal signal file check.
+         args += ['--start-with-python']

          if foreground:
              # Replaces the current process with the API server
sky/server/requests/executor.py CHANGED
@@ -424,6 +424,7 @@ def _request_execution_wrapper(request_id: str,
          os.close(original_stderr)
          original_stderr = None

+     request_name = None
      try:
          # As soon as the request is updated with the executor PID, we can
          # receive SIGTERM from cancellation. So, we update the request inside
@@ -515,7 +516,8 @@
              annotations.clear_request_level_cache()
              with metrics_utils.time_it(name='release_memory', group='internal'):
                  common_utils.release_memory()
-             _record_memory_metrics(request_name, proc, rss_begin, peak_rss)
+             if request_name is not None:
+                 _record_memory_metrics(request_name, proc, rss_begin, peak_rss)
          except Exception as e:  # pylint: disable=broad-except
              logger.error(f'Failed to record memory metrics: '
                           f'{common_utils.format_exception(e)}')
sky/server/server.py CHANGED
@@ -1968,6 +1968,7 @@ if __name__ == '__main__':
      # Serve metrics on a separate port to isolate it from the application APIs:
      # metrics port will not be exposed to the public network typically.
      parser.add_argument('--metrics-port', default=9090, type=int)
+     parser.add_argument('--start-with-python', action='store_true')
      cmd_args = parser.parse_args()
      if cmd_args.port == cmd_args.metrics_port:
          logger.error('port and metrics-port cannot be the same, exiting.')
@@ -1982,6 +1983,10 @@
          logger.error(f'Port {cmd_args.port} is not available, exiting.')
          raise RuntimeError(f'Port {cmd_args.port} is not available')

+     if not cmd_args.start_with_python:
+         # Maybe touch the signal file on API server startup.
+         managed_job_utils.is_consolidation_mode(on_api_restart=True)
+
      # Show the privacy policy if it is not already shown. We place it here so
      # that it is shown only when the API server is started.
      usage_lib.maybe_show_privacy_policy()
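Taken together, the changes in sky/jobs/utils.py, sky/server/common.py and sky/server/server.py gate consolidation mode on a restart marker: the config value only takes effect once the API server has been (re)started with the option enabled, which is recorded by touching a signal file; the new --start-with-python flag marks launches that already went through the Python startup path, so the __main__ block does not touch the marker twice. A simplified standalone sketch of that gating idea (the path and function names below are illustrative, not the SkyPilot API):

    import pathlib

    # Illustrative marker path; SkyPilot keeps its marker under ~/.sky.
    _MARKER = pathlib.Path('~/.myapp/.consolidation_reloaded').expanduser()


    def consolidation_enabled(config_value: bool, on_server_start: bool = False) -> bool:
        if on_server_start:
            # On (re)start, record that the running server reflects the config.
            if config_value:
                _MARKER.parent.mkdir(parents=True, exist_ok=True)
                _MARKER.touch()
            return config_value
        if not _MARKER.exists():
            # Config flipped on, but the server has not been restarted yet.
            return False
        if not config_value:
            # Config flipped off: drop the marker so no restart is needed.
            _MARKER.unlink()
            return False
        return True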
sky/sky_logging.py CHANGED
@@ -109,7 +109,6 @@ def _setup_logger():
      global _default_handler
      if _default_handler is None:
          _default_handler = EnvAwareHandler(sys.stdout)
-         _default_handler.flush = sys.stdout.flush  # type: ignore
          if env_options.Options.SHOW_DEBUG_INFO.get():
              _default_handler.setLevel(logging.DEBUG)
          else:
@@ -129,7 +128,6 @@ def _setup_logger():
      for logger_name in _SENSITIVE_LOGGER:
          logger = logging.getLogger(logger_name)
          handler_to_logger = EnvAwareHandler(sys.stdout, sensitive=True)
-         handler_to_logger.flush = sys.stdout.flush  # type: ignore
          logger.addHandler(handler_to_logger)
          logger.setLevel(logging.INFO)
      if _show_logging_prefix():
sky/skylet/constants.py CHANGED
@@ -226,7 +226,9 @@ RAY_INSTALLATION_COMMANDS = (
      f'{SKY_UV_PIP_CMD} list | grep "ray " | '
      f'grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null '
      f'|| {RAY_STATUS} || '
-     f'{SKY_UV_PIP_CMD} install -U ray[default]=={SKY_REMOTE_RAY_VERSION}; '  # pylint: disable=line-too-long
+     # The pydantic-core==2.41.3 for arm seems corrupted
+     # so we need to avoid that specific version.
+     f'{SKY_UV_PIP_CMD} install -U "ray[default]=={SKY_REMOTE_RAY_VERSION}" "pydantic-core==2.41.1"; '  # pylint: disable=line-too-long
      # In some envs, e.g. pip does not have permission to write under /opt/conda
      # ray package will be installed under ~/.local/bin. If the user's PATH does
      # not include ~/.local/bin (the pip install will have the output: `WARNING:
@@ -402,10 +404,25 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
  ]
  # When overriding the SkyPilot configs on the API server with the client one,
  # we skip the following keys because they are meant to be client-side configs.
- SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('api_server',),
-                                                        ('allowed_clouds',),
-                                                        ('workspaces',), ('db',),
-                                                        ('daemons',)]
+ # Also, we skip the consolidation mode config as those should be only set on
+ # the API server side.
+ SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [
+     ('api_server',),
+     ('allowed_clouds',),
+     ('workspaces',),
+     ('db',),
+     ('daemons',),
+     # TODO(kevin,tian): Override the whole controller config once our test
+     # infrastructure supports setting dynamic server side configs.
+     # Tests that are affected:
+     # - test_managed_jobs_ha_kill_starting
+     # - test_managed_jobs_ha_kill_running
+     # - all tests that use LOW_CONTROLLER_RESOURCE_ENV or
+     #   LOW_CONTROLLER_RESOURCE_OVERRIDE_CONFIG (won't cause test failure,
+     #   but the configs won't be applied)
+     ('jobs', 'controller', 'consolidation_mode'),
+     ('serve', 'controller', 'consolidation_mode'),
+ ]

  # Constants for Azure blob storage
  WAIT_FOR_STORAGE_ACCOUNT_CREATION = 60
sky/skylet/log_lib.py CHANGED
@@ -271,7 +271,6 @@ def run_with_log(
              stdout, stderr = context_utils.pipe_and_wait_process(
                  ctx,
                  proc,
-                 cancel_callback=subprocess_utils.kill_children_processes,
                  stdout_stream_handler=stdout_stream_handler,
                  stderr_stream_handler=stderr_stream_handler)
          elif process_stream:
sky/skylet/log_lib.pyi CHANGED
@@ -42,7 +42,7 @@ class _ProcessingArgs:
      ...


- def _get_context() -> Optional[context.Context]:
+ def _get_context() -> Optional[context.SkyPilotContext]:
      ...


sky/utils/common.py CHANGED
@@ -42,6 +42,8 @@ def refresh_server_id() -> None:
  JOB_CONTROLLER_NAME = f'{JOB_CONTROLLER_PREFIX}{SERVER_ID}'


+ # TODO(kevin): Remove this side effect and have callers call
+ # refresh_server_id() explicitly as needed.
  refresh_server_id()


sky/utils/context.py CHANGED
@@ -5,13 +5,12 @@ from collections.abc import Mapping
  import contextvars
  import copy
  import functools
- import inspect
  import os
  import pathlib
  import subprocess
  import sys
- from typing import (Callable, Dict, Iterator, MutableMapping, Optional, TextIO,
-                     TYPE_CHECKING, TypeVar)
+ from typing import (Any, Callable, Coroutine, Dict, Iterator, MutableMapping,
+                     Optional, TextIO, TYPE_CHECKING, TypeVar)

  from typing_extensions import ParamSpec

@@ -19,7 +18,7 @@ if TYPE_CHECKING:
      from sky.skypilot_config import ConfigContext


- class Context(object):
+ class SkyPilotContext(object):
      """SkyPilot typed context vars for threads and coroutines.

      This is a wrapper around `contextvars.ContextVar` that provides a typed
@@ -114,7 +113,14 @@ class Context(object):
              self._log_file_handle.close()
              self._log_file_handle = None

-     def copy(self) -> 'Context':
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         del exc_type, exc_val, exc_tb
+         self.cleanup()
+
+     def copy(self) -> 'SkyPilotContext':
          """Create a copy of the context.

          Changes to the current context after this call will not affect the copy.
@@ -123,18 +129,18 @@ class Context(object):
          The new context will get an independent copy of the config context.
          Cancellation of the current context will not be propagated to the copy.
          """
-         new_context = Context()
+         new_context = SkyPilotContext()
          new_context.redirect_log(self._log_file)
          new_context.env_overrides = self.env_overrides.copy()
          new_context.config_context = copy.deepcopy(self.config_context)
          return new_context


- _CONTEXT = contextvars.ContextVar[Optional[Context]]('sky_context',
-                                                       default=None)
+ _CONTEXT = contextvars.ContextVar[Optional[SkyPilotContext]]('sky_context',
+                                                              default=None)


- def get() -> Optional[Context]:
+ def get() -> Optional[SkyPilotContext]:
      """Get the current SkyPilot context.

      If the context is not initialized, get() will return None. This helps
@@ -200,7 +206,7 @@ class ContextualEnviron(MutableMapping[str, str]):

      def __iter__(self) -> Iterator[str]:

-         def iter_from_context(ctx: Context) -> Iterator[str]:
+         def iter_from_context(ctx: SkyPilotContext) -> Iterator[str]:
              deleted_keys = set()
              for key, value in ctx.env_overrides.items():
                  if value is None:
@@ -311,56 +317,56 @@ def contextual(func: Callable[P, T]) -> Callable[P, T]:
      context that inherits the values from the existing context.
      """

+     def run_in_context(*args: P.args, **kwargs: P.kwargs) -> T:
+         # Within the new contextvars Context, set up the SkyPilotContext.
+         original_ctx = get()
+         with initialize(original_ctx):
+             return func(*args, **kwargs)
+
      @functools.wraps(func)
      def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+         # Create a copy of the current contextvars Context so that setting the
+         # SkyPilotContext does not affect the caller's context in async
+         # environments.
+         context = contextvars.copy_context()
+         return context.run(run_in_context, *args, **kwargs)
+
+     return wrapper
+
+
+ def contextual_async(
+     func: Callable[P, Coroutine[Any, Any, T]]
+ ) -> Callable[P, Coroutine[Any, Any, T]]:
+     """Decorator to initialize a context before executing the function.
+
+     If a context is already initialized, this decorator will create a new
+     context that inherits the values from the existing context.
+     """
+
+     async def run_in_context(*args: P.args, **kwargs: P.kwargs) -> T:
+         # Within the new contextvars Context, set up the SkyPilotContext.
          original_ctx = get()
-         initialize(original_ctx)
-         ctx = get()
-         cleanup_after_await = False
-
-         def cleanup():
-             try:
-                 if ctx is not None:
-                     ctx.cleanup()
-             finally:
-                 # Note: _CONTEXT.reset() is not reliable - may fail with
-                 # ValueError: <Token ... at ...> was created in a different
-                 # Context
-                 # We must make sure this happens because otherwise we may try to
-                 # write to the wrong log.
-                 _CONTEXT.set(original_ctx)
-
-         # There are two cases:
-         # 1. The function is synchronous (that is, return type is not awaitable)
-         #    In this case, we use a finally block to cleanup the context.
-         # 2. The function is asynchronous (that is, return type is awaitable)
-         #    In this case, we need to construct an async def wrapper and await
-         #    the value, then call the cleanup function in the finally block.
-
-         async def await_with_cleanup(awaitable):
-             try:
-                 return await awaitable
-             finally:
-                 cleanup()
-
-         try:
-             ret = func(*args, **kwargs)
-             if inspect.isawaitable(ret):
-                 cleanup_after_await = True
-                 return await_with_cleanup(ret)
-             else:
-                 return ret
-         finally:
-             if not cleanup_after_await:
-                 cleanup()
+         with initialize(original_ctx):
+             return await func(*args, **kwargs)
+
+     @functools.wraps(func)
+     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+         # Create a copy of the current contextvars Context so that setting the
+         # SkyPilotContext does not affect the caller's context in async
+         # environments.
+         context = contextvars.copy_context()
+         return await context.run(run_in_context, *args, **kwargs)

      return wrapper


- def initialize(base_context: Optional[Context] = None) -> None:
+ def initialize(
+     base_context: Optional[SkyPilotContext] = None) -> SkyPilotContext:
      """Initialize the current SkyPilot context."""
-     new_context = base_context.copy() if base_context is not None else Context()
+     new_context = base_context.copy(
+     ) if base_context is not None else SkyPilotContext()
      _CONTEXT.set(new_context)
+     return new_context


  class _ContextualStream:
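The rewritten contextual/contextual_async decorators lean on contextvars.copy_context(): the wrapped function runs inside a copy of the caller's interpreter context, so setting the SkyPilot context there cannot leak back into the caller, and the old manual cleanup/reset logic is no longer needed. A minimal standalone illustration of that isolation pattern, using a plain ContextVar rather than SkyPilotContext:

    import contextvars
    import functools

    _CURRENT = contextvars.ContextVar('current', default='outer')


    def isolated(func):
        """Run func in a copy of the caller's contextvars Context."""

        def run(*args, **kwargs):
            _CURRENT.set('inner')  # Only visible inside the copied context.
            return func(*args, **kwargs)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return contextvars.copy_context().run(run, *args, **kwargs)

        return wrapper


    @isolated
    def report():
        return _CURRENT.get()


    print(report())        # 'inner' - set inside the copied context
    print(_CURRENT.get())  # 'outer' - the caller's context is untouched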