skypilot-nightly 1.0.0.dev20251013__py3-none-any.whl → 1.0.0.dev20251015__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic according to the registry's scanner.

Files changed (57)
  1. sky/__init__.py +2 -2
  2. sky/authentication.py +9 -2
  3. sky/backends/backend_utils.py +62 -40
  4. sky/backends/cloud_vm_ray_backend.py +8 -6
  5. sky/catalog/kubernetes_catalog.py +19 -25
  6. sky/client/cli/command.py +53 -19
  7. sky/client/sdk.py +13 -1
  8. sky/dashboard/out/404.html +1 -1
  9. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  10. sky/dashboard/out/clusters/[cluster].html +1 -1
  11. sky/dashboard/out/clusters.html +1 -1
  12. sky/dashboard/out/config.html +1 -1
  13. sky/dashboard/out/index.html +1 -1
  14. sky/dashboard/out/infra/[context].html +1 -1
  15. sky/dashboard/out/infra.html +1 -1
  16. sky/dashboard/out/jobs/[job].html +1 -1
  17. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  18. sky/dashboard/out/jobs.html +1 -1
  19. sky/dashboard/out/users.html +1 -1
  20. sky/dashboard/out/volumes.html +1 -1
  21. sky/dashboard/out/workspace/new.html +1 -1
  22. sky/dashboard/out/workspaces/[name].html +1 -1
  23. sky/dashboard/out/workspaces.html +1 -1
  24. sky/jobs/controller.py +122 -145
  25. sky/jobs/recovery_strategy.py +59 -82
  26. sky/jobs/scheduler.py +5 -5
  27. sky/jobs/state.py +65 -21
  28. sky/jobs/utils.py +58 -22
  29. sky/metrics/utils.py +27 -6
  30. sky/provision/common.py +2 -0
  31. sky/provision/instance_setup.py +10 -2
  32. sky/provision/kubernetes/instance.py +34 -10
  33. sky/provision/kubernetes/utils.py +53 -39
  34. sky/server/common.py +4 -2
  35. sky/server/requests/executor.py +3 -1
  36. sky/server/requests/preconditions.py +2 -4
  37. sky/server/requests/requests.py +13 -23
  38. sky/server/server.py +5 -0
  39. sky/sky_logging.py +0 -2
  40. sky/skylet/constants.py +22 -5
  41. sky/skylet/log_lib.py +0 -1
  42. sky/skylet/log_lib.pyi +1 -1
  43. sky/utils/asyncio_utils.py +18 -0
  44. sky/utils/common.py +2 -0
  45. sky/utils/context.py +57 -51
  46. sky/utils/context_utils.py +2 -2
  47. sky/utils/controller_utils.py +35 -8
  48. sky/utils/locks.py +20 -5
  49. sky/utils/subprocess_utils.py +4 -3
  50. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251015.dist-info}/METADATA +38 -37
  51. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251015.dist-info}/RECORD +57 -56
  52. /sky/dashboard/out/_next/static/{MtlDUf-nH1hhcy7xwbCj3 → -bih7JVStsXyeasac-dvQ}/_buildManifest.js +0 -0
  53. /sky/dashboard/out/_next/static/{MtlDUf-nH1hhcy7xwbCj3 → -bih7JVStsXyeasac-dvQ}/_ssgManifest.js +0 -0
  54. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251015.dist-info}/WHEEL +0 -0
  55. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251015.dist-info}/entry_points.txt +0 -0
  56. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251015.dist-info}/licenses/LICENSE +0 -0
  57. {skypilot_nightly-1.0.0.dev20251013.dist-info → skypilot_nightly-1.0.0.dev20251015.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py CHANGED
@@ -8,7 +8,6 @@ import asyncio
 import collections
 import datetime
 import enum
-import logging
 import os
 import pathlib
 import re
@@ -84,6 +83,7 @@ _LOG_STREAM_CHECK_CONTROLLER_GAP_SECONDS = 5

 _JOB_STATUS_FETCH_MAX_RETRIES = 3
 _JOB_K8S_TRANSIENT_NW_MSG = 'Unable to connect to the server: dial tcp'
+_JOB_STATUS_FETCH_TIMEOUT_SECONDS = 30

 _JOB_WAITING_STATUS_MESSAGE = ux_utils.spinner_message(
     'Waiting for task to start[/]'
@@ -101,6 +101,13 @@ _JOB_CANCELLED_MESSAGE = (
 # update the state.
 _FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 120

+# After enabling consolidation mode, we need to restart the API server to get
+# the jobs refresh daemon and the correct number of executors. We use this
+# file to indicate that the API server has been restarted after enabling
+# consolidation mode.
+_JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE = (
+    '~/.sky/.jobs_controller_consolidation_reloaded_signal')
+

 class ManagedJobQueueResultType(enum.Enum):
     """The type of the managed job queue result."""
@@ -117,9 +124,8 @@ class UserSignal(enum.Enum):

 # ====== internal functions ======
 def terminate_cluster(
-    cluster_name: str,
-    max_retry: int = 6,
-    _logger: logging.Logger = logger,  # pylint: disable=invalid-name
+    cluster_name: str,
+    max_retry: int = 6,
 ) -> None:
     """Terminate the cluster."""
     from sky import core  # pylint: disable=import-outside-toplevel
@@ -143,18 +149,18 @@ def terminate_cluster(
             return
         except exceptions.ClusterDoesNotExist:
             # The cluster is already down.
-            _logger.debug(f'The cluster {cluster_name} is already down.')
+            logger.debug(f'The cluster {cluster_name} is already down.')
             return
         except Exception as e:  # pylint: disable=broad-except
             retry_cnt += 1
             if retry_cnt >= max_retry:
                 raise RuntimeError(
                     f'Failed to terminate the cluster {cluster_name}.') from e
-            _logger.error(
+            logger.error(
                 f'Failed to terminate the cluster {cluster_name}. Retrying.'
                 f'Details: {common_utils.format_exception(e)}')
             with ux_utils.enable_traceback():
-                _logger.error(f'  Traceback: {traceback.format_exc()}')
+                logger.error(f'  Traceback: {traceback.format_exc()}')
             time.sleep(backoff.current_backoff())

@@ -202,13 +208,39 @@ def _validate_consolidation_mode_config(
 # API Server. Under the hood, we submit the job monitoring logic as processes
 # directly in the API Server.
 # Use LRU Cache so that the check is only done once.
-@annotations.lru_cache(scope='request', maxsize=1)
-def is_consolidation_mode() -> bool:
+@annotations.lru_cache(scope='request', maxsize=2)
+def is_consolidation_mode(on_api_restart: bool = False) -> bool:
     if os.environ.get(constants.OVERRIDE_CONSOLIDATION_MODE) is not None:
         return True

-    consolidation_mode = skypilot_config.get_nested(
+    config_consolidation_mode = skypilot_config.get_nested(
         ('jobs', 'controller', 'consolidation_mode'), default_value=False)
+
+    signal_file = pathlib.Path(
+        _JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE).expanduser()
+
+    restart_signal_file_exists = signal_file.exists()
+    consolidation_mode = (config_consolidation_mode and
+                          restart_signal_file_exists)
+
+    if on_api_restart:
+        if config_consolidation_mode:
+            signal_file.touch()
+    else:
+        if not restart_signal_file_exists:
+            if config_consolidation_mode:
+                logger.warning(f'{colorama.Fore.YELLOW}Consolidation mode for '
+                               'managed jobs is enabled in the server config, '
+                               'but the API server has not been restarted yet. '
+                               'Please restart the API server to enable it.'
+                               f'{colorama.Style.RESET_ALL}')
+            return False
+        elif not config_consolidation_mode:
+            # Clean up the signal file if consolidation mode is disabled in
+            # the config. This allows the user to disable consolidation mode
+            # without restarting the API server.
+            signal_file.unlink()
+
     # We should only do this check on the API server, as the controller will
     # not have the related config and will always appear to be disabled for
     # consolidation mode. Check #6611 for more details.
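For orientation, here is a minimal standalone sketch of the restart-signal gating that the new is_consolidation_mode logic implements. The function name and the bare config_on flag are hypothetical simplifications, not SkyPilot APIs:

import pathlib

_SIGNAL = pathlib.Path(
    '~/.sky/.jobs_controller_consolidation_reloaded_signal').expanduser()

def consolidation_active(config_on: bool, on_api_restart: bool = False) -> bool:
    """Hypothetical distillation of the gating above."""
    if on_api_restart:
        if config_on:
            _SIGNAL.touch()  # record that a restart happened with the flag on
        return config_on
    if not _SIGNAL.exists():
        # Enabled in config but the server was never restarted: stay off.
        return False
    if not config_on:
        _SIGNAL.unlink()  # disabling takes effect without a restart
        return False
    return True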
@@ -269,8 +301,7 @@ def ha_recovery_for_consolidation_mode():

 async def get_job_status(
         backend: 'backends.CloudVmRayBackend', cluster_name: str,
-        job_id: Optional[int],
-        job_logger: logging.Logger) -> Optional['job_lib.JobStatus']:
+        job_id: Optional[int]) -> Optional['job_lib.JobStatus']:
     """Check the status of the job running on a managed job cluster.

     It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
@@ -282,26 +313,28 @@ async def get_job_status(
     if handle is None:
         # This can happen if the cluster was preempted and background status
         # refresh already noticed and cleaned it up.
-        job_logger.info(f'Cluster {cluster_name} not found.')
+        logger.info(f'Cluster {cluster_name} not found.')
         return None
     assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
     job_ids = None if job_id is None else [job_id]
     for i in range(_JOB_STATUS_FETCH_MAX_RETRIES):
         try:
-            job_logger.info('=== Checking the job status... ===')
-            statuses = await context_utils.to_thread(backend.get_job_status,
-                                                     handle,
-                                                     job_ids=job_ids,
-                                                     stream_logs=False)
+            logger.info('=== Checking the job status... ===')
+            statuses = await asyncio.wait_for(
+                context_utils.to_thread(backend.get_job_status,
+                                        handle,
+                                        job_ids=job_ids,
+                                        stream_logs=False),
+                timeout=_JOB_STATUS_FETCH_TIMEOUT_SECONDS)
             status = list(statuses.values())[0]
             if status is None:
-                job_logger.info('No job found.')
+                logger.info('No job found.')
             else:
-                job_logger.info(f'Job status: {status}')
-                job_logger.info('=' * 34)
+                logger.info(f'Job status: {status}')
+                logger.info('=' * 34)
             return status
         except (exceptions.CommandError, grpc.RpcError, grpc.FutureTimeoutError,
-                ValueError, TypeError) as e:
+                ValueError, TypeError, asyncio.TimeoutError) as e:
             # Note: Each of these exceptions has some additional conditions to
             # limit how we handle it and whether or not we catch it.
             # Retry on k8s transient network errors. This is useful when using
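The timeout pattern above — a blocking backend call moved to a worker thread and bounded with asyncio.wait_for — can be reproduced in isolation. A minimal sketch, with a made-up slow_status_query standing in for backend.get_job_status:

import asyncio
import time

def slow_status_query() -> str:
    time.sleep(60)  # simulates a status fetch that hangs
    return 'RUNNING'

async def fetch_with_timeout(timeout: float = 30):
    try:
        # Note: on timeout the awaiting task is cancelled, but the worker
        # thread itself keeps running to completion in the background.
        return await asyncio.wait_for(asyncio.to_thread(slow_status_query),
                                      timeout=timeout)
    except asyncio.TimeoutError:
        return None  # treated as a transient error and retried, as above

print(asyncio.run(fetch_with_timeout(timeout=0.1)))  # -> None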
@@ -322,6 +355,9 @@ async def get_job_status(
                 is_transient_error = True
             elif isinstance(e, grpc.FutureTimeoutError):
                 detailed_reason = 'Timeout'
+            elif isinstance(e, asyncio.TimeoutError):
+                detailed_reason = ('Job status check timed out after '
+                                   f'{_JOB_STATUS_FETCH_TIMEOUT_SECONDS}s')
             # TODO(cooperc): Gracefully handle these exceptions in the backend.
             elif isinstance(e, ValueError):
                 # If the cluster yaml is deleted in the middle of getting the
sky/metrics/utils.py CHANGED
@@ -48,8 +48,15 @@ SKY_APISERVER_CODE_DURATION_SECONDS = prom.Histogram(
     'sky_apiserver_code_duration_seconds',
     'Time spent processing code',
     ['name', 'group'],
-    buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
-             60.0, 120.0, float('inf')),
+    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+             0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+             5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+             50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+             240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+             420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+             600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+             780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+             960.0, 980.0, 1000.0, float('inf')),
 )

 # Total number of API server requests, grouped by path, method, and status.
@@ -65,16 +72,30 @@ SKY_APISERVER_REQUEST_DURATION_SECONDS = prom.Histogram(
     'sky_apiserver_request_duration_seconds',
     'Time spent processing API server requests',
     ['path', 'method', 'status'],
-    buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
-             60.0, 120.0, float('inf')),
+    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+             0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+             5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+             50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+             240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+             420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+             600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+             780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+             960.0, 980.0, 1000.0, float('inf')),
 )

 SKY_APISERVER_EVENT_LOOP_LAG_SECONDS = prom.Histogram(
     'sky_apiserver_event_loop_lag_seconds',
     'Scheduling delay of the server event loop',
     ['pid'],
-    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 20.0,
-             60.0, float('inf')),
+    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+             0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+             5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+             50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+             240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+             420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+             600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+             780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+             960.0, 980.0, 1000.0, float('inf')),
 )

 SKY_APISERVER_WEBSOCKET_CONNECTIONS = prom.Gauge(
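These changes replace the coarse latency buckets with a much finer ladder out to 1000 s. As a sketch of the mechanism (prometheus_client is the library aliased as prom here; the metric name below is made up), each observation is counted into every bucket whose upper bound is at least the observed value:

import prometheus_client as prom

DEMO_LATENCY = prom.Histogram(
    'demo_request_duration_seconds',
    'Demo request latency',
    buckets=(0.001, 0.01, 0.1, 1.0, 10.0, 60.0, float('inf')),
)

DEMO_LATENCY.observe(0.004)  # first matching bucket is le="0.01"

Finer buckets improve quantile estimates computed from histogram_quantile-style queries, at the cost of more time series per label combination.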
sky/provision/common.py CHANGED
@@ -97,6 +97,8 @@ class InstanceInfo:
     external_ip: Optional[str]
     tags: Dict[str, str]
     ssh_port: int = 22
+    # The internal service address of the instance on Kubernetes.
+    internal_svc: Optional[str] = None

     def get_feasible_ip(self) -> str:
         """Get the most feasible IPs of the instance. This function returns
sky/provision/instance_setup.py CHANGED
@@ -434,8 +434,16 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
     # use the external IP of the head node.
     use_external_ip = cluster_info.custom_ray_options.pop(
         'use_external_ip', False)
-    head_ip = (head_instance.internal_ip
-               if not use_external_ip else head_instance.external_ip)
+
+    if use_external_ip:
+        head_ip = head_instance.external_ip
+    else:
+        # For Kubernetes, use the internal service address of the head node.
+        # Keep this consistent with the logic in kubernetes-ray.yml.j2.
+        if head_instance.internal_svc:
+            head_ip = head_instance.internal_svc
+        else:
+            head_ip = head_instance.internal_ip

     ray_cmd = ray_worker_start_command(custom_resource,
                                        cluster_info.custom_ray_options,
sky/provision/kubernetes/instance.py CHANGED
@@ -959,12 +959,19 @@ def _create_pods(region: str, cluster_name: str, cluster_name_on_cloud: str,

     def _create_resource_thread(i: int):
         pod_spec_copy = copy.deepcopy(pod_spec)
-        if head_pod_name is None and i == 0:
-            # First pod should be head if no head exists
-            pod_spec_copy['metadata']['labels'].update(constants.HEAD_NODE_TAGS)
-            head_selector = _head_service_selector(cluster_name_on_cloud)
-            pod_spec_copy['metadata']['labels'].update(head_selector)
-            pod_spec_copy['metadata']['name'] = f'{cluster_name_on_cloud}-head'
+        # 0 is for the head pod, while 1+ is for worker pods.
+        if i == 0:
+            if head_pod_name is None:
+                # First pod should be head if no head exists.
+                pod_spec_copy['metadata']['labels'].update(
+                    constants.HEAD_NODE_TAGS)
+                head_selector = _head_service_selector(cluster_name_on_cloud)
+                pod_spec_copy['metadata']['labels'].update(head_selector)
+                pod_spec_copy['metadata'][
+                    'name'] = f'{cluster_name_on_cloud}-head'
+            else:
+                # If the head pod already exists, skip creating it.
+                return
         else:
             # Worker pods
             pod_spec_copy['metadata']['labels'].update(
@@ -1105,9 +1112,16 @@ def _create_pods(region: str, cluster_name: str, cluster_name_on_cloud: str,
                        'and then up the cluster again.')
             raise exceptions.InconsistentHighAvailabilityError(message)

-    # Create pods in parallel
-    created_resources = subprocess_utils.run_in_parallel(
-        _create_resource_thread, list(range(to_start_count)), _NUM_THREADS)
+    created_resources = []
+    if to_start_count > 0:
+        # Create pods in parallel.
+        # Use `config.count` instead of `to_start_count` to keep the index of
+        # the pods consistent, especially for the case where some pods are
+        # down due to node failure or manual termination, etc., and the
+        # cluster is then launched again to recreate them.
+        # The existing pods will be skipped in _create_resource_thread.
+        created_resources = subprocess_utils.run_in_parallel(
+            _create_resource_thread, list(range(config.count)), _NUM_THREADS)

     if to_create_deployment:
         deployments = copy.deepcopy(created_resources)
@@ -1350,6 +1364,9 @@ def get_cluster_info(
                 external_ip=None,
                 ssh_port=port,
                 tags=pod.metadata.labels,
+                # TODO(hailong): `cluster.local` may need to be configurable.
+                # The service name is the same as the pod name for now.
+                internal_svc=f'{pod_name}.{namespace}.svc.cluster.local',
             )
         ]
         if _is_head(pod):
@@ -1388,6 +1405,13 @@ def get_cluster_info(
     logger.debug(
         f'Using ssh user {ssh_user} for cluster {cluster_name_on_cloud}')

+    # cpu_request may be a string like `100m`; parse and convert it to a float.
+    num_cpus = kubernetes_utils.parse_cpu_or_gpu_resource_to_float(cpu_request)
+    # 'num-cpus' for Ray must be an integer, but we should not set it to 0 if
+    # cpus is <1.
+    # Keep consistent with the logic in clouds/kubernetes.py.
+    str_cpus = str(max(int(num_cpus), 1))
+
     return common.ClusterInfo(
         instances=pods,
         head_instance_id=head_pod_name,
@@ -1397,7 +1421,7 @@
         # problems for other pods.
         custom_ray_options={
             'object-store-memory': 500000000,
-            'num-cpus': cpu_request,
+            'num-cpus': str_cpus,
         },
         provider_name='kubernetes',
         provider_config=provider_config)
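The internal_svc value follows the standard Kubernetes DNS form for a pod reached through a service, <service>.<namespace>.svc.<cluster-domain>. A small illustrative helper (hypothetical, assuming the default cluster.local domain flagged in the TODO above):

def pod_service_fqdn(pod_name: str, namespace: str,
                     cluster_domain: str = 'cluster.local') -> str:
    # The service is named after the pod, per the comment in the diff.
    return f'{pod_name}.{namespace}.svc.{cluster_domain}'

assert (pod_service_fqdn('mycluster-head', 'default') ==
        'mycluster-head.default.svc.cluster.local')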
sky/provision/kubernetes/utils.py CHANGED
@@ -1299,30 +1299,52 @@ class V1Pod:


 @_retry_on_error(resource_type='pod')
-def get_all_pods_in_kubernetes_cluster(*,
-                                       context: Optional[str] = None
-                                      ) -> List[V1Pod]:
-    """Gets pods in all namespaces in kubernetes cluster indicated by context.
-
-    Used for computing cluster resource usage.
+def get_allocated_gpu_qty_by_node(
+    *,
+    context: Optional[str] = None,
+) -> Dict[str, int]:
+    """Gets allocated GPU quantity by each node by fetching pods in
+    all namespaces in the kubernetes cluster indicated by context.
     """
     if context is None:
         context = get_current_kube_config_context_name()
+    non_included_pod_statuses = POD_STATUSES.copy()
+    status_filters = ['Running', 'Pending']
+    if status_filters is not None:
+        non_included_pod_statuses -= set(status_filters)
+    field_selector = ','.join(
+        [f'status.phase!={status}' for status in non_included_pod_statuses])

     # Return raw urllib3.HTTPResponse object so that we can parse the json
     # more efficiently.
     response = kubernetes.core_api(context).list_pod_for_all_namespaces(
-        _request_timeout=kubernetes.API_TIMEOUT, _preload_content=False)
+        _request_timeout=kubernetes.API_TIMEOUT,
+        _preload_content=False,
+        field_selector=field_selector)
     try:
-        pods = [
-            V1Pod.from_dict(item_dict) for item_dict in ijson.items(
-                response, 'items.item', buf_size=IJSON_BUFFER_SIZE)
-        ]
+        allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
+        for item_dict in ijson.items(response,
+                                     'items.item',
+                                     buf_size=IJSON_BUFFER_SIZE):
+            pod = V1Pod.from_dict(item_dict)
+            if should_exclude_pod_from_gpu_allocation(pod):
+                logger.debug(
+                    f'Excluding pod {pod.metadata.name} from GPU count '
+                    f'calculations on node {pod.spec.node_name}')
+                continue
+            # Iterate over all the containers in the pod and sum the
+            # GPU requests.
+            pod_allocated_qty = 0
+            for container in pod.spec.containers:
+                if container.resources.requests:
+                    pod_allocated_qty += get_node_accelerator_count(
+                        context, container.resources.requests)
+            if pod_allocated_qty > 0 and pod.spec.node_name:
+                allocated_qty_by_node[pod.spec.node_name] += pod_allocated_qty
+        return allocated_qty_by_node
     finally:
         response.release_conn()

-    return pods
-

 def check_instance_fits(context: Optional[str],
                         instance: str) -> Tuple[bool, Optional[str]]:
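The field selector built above pushes the phase filtering to the Kubernetes API server, so only Running and Pending pods are returned at all. A self-contained sketch of the string it produces (POD_STATUSES here mirrors the module-level set the code assumes):

POD_STATUSES = {'Pending', 'Running', 'Succeeded', 'Failed', 'Unknown'}

excluded = POD_STATUSES - {'Running', 'Pending'}
field_selector = ','.join(f'status.phase!={phase}' for phase in sorted(excluded))
print(field_selector)
# status.phase!=Failed,status.phase!=Succeeded,status.phase!=Unknown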
@@ -2219,6 +2241,15 @@ def get_kube_config_context_namespace(
     return DEFAULT_NAMESPACE


+def parse_cpu_or_gpu_resource_to_float(resource_str: str) -> float:
+    if not resource_str:
+        return 0.0
+    if resource_str[-1] == 'm':
+        return float(resource_str[:-1]) / 1000
+    else:
+        return float(resource_str)
+
+
 def parse_cpu_or_gpu_resource(resource_qty_str: str) -> Union[int, float]:
     resource_str = str(resource_qty_str)
     if resource_str[-1] == 'm':
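The new helper treats the Kubernetes millicore suffix m as thousandths of a core. Expected behavior, shown with an inline copy for illustration:

def parse_cpu_or_gpu_resource_to_float(resource_str: str) -> float:
    # Mirrors the helper added above.
    if not resource_str:
        return 0.0
    if resource_str[-1] == 'm':
        return float(resource_str[:-1]) / 1000  # millicores -> cores
    return float(resource_str)

assert parse_cpu_or_gpu_resource_to_float('100m') == 0.1
assert parse_cpu_or_gpu_resource_to_float('2') == 2.0
assert parse_cpu_or_gpu_resource_to_float('') == 0.0
# Downstream, Ray's integer `num-cpus` is floored but never set to 0:
assert str(max(int(parse_cpu_or_gpu_resource_to_float('100m')), 1)) == '1'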
@@ -3006,41 +3037,24 @@ def get_kubernetes_node_info(
     label_keys = lf.get_label_keys()

     # Check if all nodes have no accelerators to avoid fetching pods
-    any_node_has_accelerators = False
+    has_accelerator_nodes = False
     for node in nodes:
         accelerator_count = get_node_accelerator_count(context,
                                                        node.status.allocatable)
         if accelerator_count > 0:
-            any_node_has_accelerators = True
+            has_accelerator_nodes = True
             break

-    # Get the pods to get the real-time resource usage
-    pods = None
+    # Get the allocated GPU quantity by each node
     allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
-    if any_node_has_accelerators:
+    error_on_get_allocated_gpu_qty_by_node = False
+    if has_accelerator_nodes:
         try:
-            pods = get_all_pods_in_kubernetes_cluster(context=context)
-            # Pre-compute allocated accelerator count per node
-            for pod in pods:
-                if pod.status.phase in ['Running', 'Pending']:
-                    # Skip pods that should not count against GPU count
-                    if should_exclude_pod_from_gpu_allocation(pod):
-                        logger.debug(f'Excluding low priority pod '
-                                     f'{pod.metadata.name} from GPU allocation '
-                                     f'calculations')
-                        continue
-                    # Iterate over all the containers in the pod and sum the
-                    # GPU requests
-                    pod_allocated_qty = 0
-                    for container in pod.spec.containers:
-                        if container.resources.requests:
-                            pod_allocated_qty += get_node_accelerator_count(
-                                context, container.resources.requests)
-                    if pod_allocated_qty > 0:
-                        allocated_qty_by_node[
-                            pod.spec.node_name] += pod_allocated_qty
+            allocated_qty_by_node = get_allocated_gpu_qty_by_node(
+                context=context)
         except kubernetes.api_exception() as e:
             if e.status == 403:
+                error_on_get_allocated_gpu_qty_by_node = True
                 pass
             else:
                 raise
@@ -3085,7 +3099,7 @@
                 ip_address=node_ip)
             continue

-        if pods is None:
+        if not has_accelerator_nodes or error_on_get_allocated_gpu_qty_by_node:
             accelerators_available = -1
         else:
             allocated_qty = allocated_qty_by_node[node.metadata.name]
sky/server/common.py CHANGED
@@ -554,8 +554,8 @@ def _start_api_server(deploy: bool = False,
         # pylint: disable=import-outside-toplevel
         import sky.jobs.utils as job_utils
         max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
-                      if job_utils.is_consolidation_mode() else
-                      server_constants.MIN_AVAIL_MEM_GB)
+                      if job_utils.is_consolidation_mode(on_api_restart=True)
+                      else server_constants.MIN_AVAIL_MEM_GB)
         if avail_mem_size_gb <= max_memory:
             logger.warning(
                 f'{colorama.Fore.YELLOW}Your SkyPilot API server machine only '
@@ -571,6 +571,8 @@ def _start_api_server(deploy: bool = False,
             args += [f'--host={host}']
         if metrics_port is not None:
             args += [f'--metrics-port={metrics_port}']
+        # Use this argument to disable the internal signal file check.
+        args += ['--start-with-python']

         if foreground:
             # Replaces the current process with the API server
sky/server/requests/executor.py CHANGED
@@ -424,6 +424,7 @@ def _request_execution_wrapper(request_id: str,
         os.close(original_stderr)
         original_stderr = None

+    request_name = None
     try:
         # As soon as the request is updated with the executor PID, we can
         # receive SIGTERM from cancellation. So, we update the request inside
@@ -515,7 +516,8 @@
             annotations.clear_request_level_cache()
             with metrics_utils.time_it(name='release_memory', group='internal'):
                 common_utils.release_memory()
-            _record_memory_metrics(request_name, proc, rss_begin, peak_rss)
+            if request_name is not None:
+                _record_memory_metrics(request_name, proc, rss_begin, peak_rss)
         except Exception as e:  # pylint: disable=broad-except
             logger.error(f'Failed to record memory metrics: '
                          f'{common_utils.format_exception(e)}')
sky/server/requests/preconditions.py CHANGED
@@ -112,10 +112,8 @@ class Precondition(abc.ABC):
                     return True
                 if status_msg is not None and status_msg != last_status_msg:
                     # Update the status message if it has changed.
-                    async with api_requests.update_request_async(
-                            self.request_id) as req:
-                        assert req is not None, self.request_id
-                        req.status_msg = status_msg
+                    await api_requests.update_status_msg_async(
+                        self.request_id, status_msg)
                     last_status_msg = status_msg
             except (Exception, SystemExit, KeyboardInterrupt) as e:  # pylint: disable=broad-except
                 api_requests.set_request_failed(self.request_id, e)
sky/server/requests/requests.py CHANGED
@@ -14,8 +14,8 @@ import sqlite3
 import threading
 import time
 import traceback
-from typing import (Any, AsyncContextManager, Callable, Dict, Generator, List,
-                    NamedTuple, Optional, Tuple)
+from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
+                    Tuple)

 import anyio
 import colorama
@@ -32,6 +32,7 @@ from sky.server import daemons
 from sky.server.requests import payloads
 from sky.server.requests.serializers import decoders
 from sky.server.requests.serializers import encoders
+from sky.utils import asyncio_utils
 from sky.utils import common_utils
 from sky.utils import ux_utils
 from sky.utils.db import db_utils
@@ -578,27 +579,14 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:

 @init_db
 @metrics_lib.time_me
-def update_request_async(
-        request_id: str) -> AsyncContextManager[Optional[Request]]:
-    """Async version of update_request.
-
-    Returns an async context manager that yields the request record and
-    persists any in-place updates upon exit.
-    """
-
-    @contextlib.asynccontextmanager
-    async def _cm():
-        # Acquire the lock to avoid race conditions between multiple request
-        # operations, e.g. execute and cancel.
-        async with filelock.AsyncFileLock(request_lock_path(request_id)):
-            request = await _get_request_no_lock_async(request_id)
-            try:
-                yield request
-            finally:
-                if request is not None:
-                    await _add_or_update_request_no_lock_async(request)
-
-    return _cm()
+@asyncio_utils.shield
+async def update_status_msg_async(request_id: str, status_msg: str) -> None:
+    """Update the status message of a request."""
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request = await _get_request_no_lock_async(request_id)
+        if request is not None:
+            request.status_msg = status_msg
+            await _add_or_update_request_no_lock_async(request)


 _get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} '
@@ -651,6 +639,7 @@ def get_request(request_id: str) -> Optional[Request]:

 @init_db_async
 @metrics_lib.time_me_async
+@asyncio_utils.shield
 async def get_request_async(request_id: str) -> Optional[Request]:
     """Async version of get_request."""
     async with filelock.AsyncFileLock(request_lock_path(request_id)):
@@ -704,6 +693,7 @@ def create_if_not_exists(request: Request) -> bool:

 @init_db_async
 @metrics_lib.time_me_async
+@asyncio_utils.shield
 async def create_if_not_exists_async(request: Request) -> bool:
     """Async version of create_if_not_exists."""
     async with filelock.AsyncFileLock(request_lock_path(request.request_id)):
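The new @asyncio_utils.shield decorator comes from sky/utils/asyncio_utils.py (+18 lines in this diff), which is not shown here. As a plausible minimal sketch only, such a decorator could wrap the coroutine in a task guarded by asyncio.shield, so that caller cancellation cannot interrupt the lock-and-write critical section midway:

import asyncio
import functools

def shield(func):
    """Sketch: keep `func` running to completion even if the awaiting
    caller is cancelled (the caller still sees CancelledError)."""
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        task = asyncio.get_running_loop().create_task(func(*args, **kwargs))
        return await asyncio.shield(task)
    return wrapper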
sky/server/server.py CHANGED
@@ -1968,6 +1968,7 @@ if __name__ == '__main__':
     # Serve metrics on a separate port to isolate it from the application APIs:
     # metrics port will not be exposed to the public network typically.
     parser.add_argument('--metrics-port', default=9090, type=int)
+    parser.add_argument('--start-with-python', action='store_true')
     cmd_args = parser.parse_args()
     if cmd_args.port == cmd_args.metrics_port:
         logger.error('port and metrics-port cannot be the same, exiting.')
@@ -1982,6 +1983,10 @@
         logger.error(f'Port {cmd_args.port} is not available, exiting.')
         raise RuntimeError(f'Port {cmd_args.port} is not available')

+    if not cmd_args.start_with_python:
+        # Maybe touch the signal file on API server startup.
+        managed_job_utils.is_consolidation_mode(on_api_restart=True)
+
     # Show the privacy policy if it is not already shown. We place it here so
     # that it is shown only when the API server is started.
     usage_lib.maybe_show_privacy_policy()
sky/sky_logging.py CHANGED
@@ -109,7 +109,6 @@ def _setup_logger():
     global _default_handler
     if _default_handler is None:
         _default_handler = EnvAwareHandler(sys.stdout)
-        _default_handler.flush = sys.stdout.flush  # type: ignore
         if env_options.Options.SHOW_DEBUG_INFO.get():
             _default_handler.setLevel(logging.DEBUG)
         else:
@@ -129,7 +128,6 @@ def _setup_logger():
     for logger_name in _SENSITIVE_LOGGER:
         logger = logging.getLogger(logger_name)
         handler_to_logger = EnvAwareHandler(sys.stdout, sensitive=True)
-        handler_to_logger.flush = sys.stdout.flush  # type: ignore
         logger.addHandler(handler_to_logger)
         logger.setLevel(logging.INFO)
         if _show_logging_prefix():