skypilot-nightly 1.0.0.dev20250827__py3-none-any.whl → 1.0.0.dev20250829__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

The registry has flagged this release of skypilot-nightly as potentially problematic.

Files changed (86)
  1. sky/__init__.py +2 -2
  2. sky/admin_policy.py +11 -10
  3. sky/authentication.py +1 -1
  4. sky/backends/backend.py +3 -5
  5. sky/backends/backend_utils.py +140 -52
  6. sky/backends/cloud_vm_ray_backend.py +30 -25
  7. sky/backends/local_docker_backend.py +3 -8
  8. sky/backends/wheel_utils.py +35 -8
  9. sky/client/cli/command.py +41 -9
  10. sky/client/sdk.py +23 -8
  11. sky/client/sdk_async.py +6 -2
  12. sky/clouds/aws.py +118 -1
  13. sky/core.py +1 -4
  14. sky/dashboard/out/404.html +1 -1
  15. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  16. sky/dashboard/out/clusters/[cluster].html +1 -1
  17. sky/dashboard/out/clusters.html +1 -1
  18. sky/dashboard/out/config.html +1 -1
  19. sky/dashboard/out/index.html +1 -1
  20. sky/dashboard/out/infra/[context].html +1 -1
  21. sky/dashboard/out/infra.html +1 -1
  22. sky/dashboard/out/jobs/[job].html +1 -1
  23. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  24. sky/dashboard/out/jobs.html +1 -1
  25. sky/dashboard/out/users.html +1 -1
  26. sky/dashboard/out/volumes.html +1 -1
  27. sky/dashboard/out/workspace/new.html +1 -1
  28. sky/dashboard/out/workspaces/[name].html +1 -1
  29. sky/dashboard/out/workspaces.html +1 -1
  30. sky/global_user_state.py +82 -22
  31. sky/jobs/client/sdk.py +5 -2
  32. sky/jobs/recovery_strategy.py +9 -4
  33. sky/jobs/server/server.py +2 -1
  34. sky/logs/agent.py +2 -2
  35. sky/logs/aws.py +6 -3
  36. sky/provision/aws/config.py +78 -3
  37. sky/provision/aws/instance.py +45 -6
  38. sky/provision/do/utils.py +2 -1
  39. sky/provision/kubernetes/instance.py +55 -11
  40. sky/provision/kubernetes/utils.py +11 -2
  41. sky/provision/nebius/utils.py +36 -2
  42. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  43. sky/serve/client/impl.py +5 -4
  44. sky/serve/replica_managers.py +4 -3
  45. sky/serve/serve_utils.py +2 -2
  46. sky/serve/server/impl.py +3 -2
  47. sky/serve/server/server.py +2 -1
  48. sky/server/auth/oauth2_proxy.py +10 -4
  49. sky/server/common.py +4 -4
  50. sky/server/daemons.py +16 -5
  51. sky/server/requests/executor.py +5 -3
  52. sky/server/requests/payloads.py +3 -1
  53. sky/server/requests/preconditions.py +3 -2
  54. sky/server/requests/requests.py +121 -19
  55. sky/server/server.py +85 -60
  56. sky/server/stream_utils.py +7 -5
  57. sky/setup_files/dependencies.py +6 -1
  58. sky/sky_logging.py +28 -0
  59. sky/skylet/constants.py +6 -0
  60. sky/skylet/events.py +2 -3
  61. sky/skypilot_config.py +10 -10
  62. sky/task.py +1 -1
  63. sky/templates/aws-ray.yml.j2 +1 -0
  64. sky/templates/nebius-ray.yml.j2 +4 -8
  65. sky/usage/usage_lib.py +3 -2
  66. sky/utils/annotations.py +8 -2
  67. sky/utils/cluster_utils.py +3 -3
  68. sky/utils/common_utils.py +0 -72
  69. sky/utils/controller_utils.py +4 -3
  70. sky/utils/dag_utils.py +4 -4
  71. sky/utils/db/db_utils.py +11 -0
  72. sky/utils/db/migration_utils.py +1 -1
  73. sky/utils/kubernetes/config_map_utils.py +3 -3
  74. sky/utils/kubernetes_enums.py +1 -0
  75. sky/utils/lock_events.py +94 -0
  76. sky/utils/schemas.py +3 -0
  77. sky/utils/timeline.py +24 -93
  78. sky/utils/yaml_utils.py +77 -10
  79. {skypilot_nightly-1.0.0.dev20250827.dist-info → skypilot_nightly-1.0.0.dev20250829.dist-info}/METADATA +8 -2
  80. {skypilot_nightly-1.0.0.dev20250827.dist-info → skypilot_nightly-1.0.0.dev20250829.dist-info}/RECORD +86 -84
  81. /sky/dashboard/out/_next/static/{-eL7Ky3bxVivzeLHNB9U6 → hYJYFIxp_ZFONR4wTIJqZ}/_buildManifest.js +0 -0
  82. /sky/dashboard/out/_next/static/{-eL7Ky3bxVivzeLHNB9U6 → hYJYFIxp_ZFONR4wTIJqZ}/_ssgManifest.js +0 -0
  83. {skypilot_nightly-1.0.0.dev20250827.dist-info → skypilot_nightly-1.0.0.dev20250829.dist-info}/WHEEL +0 -0
  84. {skypilot_nightly-1.0.0.dev20250827.dist-info → skypilot_nightly-1.0.0.dev20250829.dist-info}/entry_points.txt +0 -0
  85. {skypilot_nightly-1.0.0.dev20250827.dist-info → skypilot_nightly-1.0.0.dev20250829.dist-info}/licenses/LICENSE +0 -0
  86. {skypilot_nightly-1.0.0.dev20250827.dist-info → skypilot_nightly-1.0.0.dev20250829.dist-info}/top_level.txt +0 -0
sky/provision/kubernetes/instance.py CHANGED
@@ -1,5 +1,6 @@
 """Kubernetes instance provisioning."""
 import copy
+import datetime
 import json
 import re
 import time
@@ -1254,9 +1255,11 @@ def get_cluster_info(
                                             provider_config=provider_config)
 
 
-def _get_pod_termination_reason(pod: Any) -> str:
+def _get_pod_termination_reason(pod: Any, cluster_name: str) -> str:
+    """Get pod termination reason and write to cluster events."""
     reasons = []
-    if pod.status.container_statuses:
+    latest_timestamp = pod.status.start_time or datetime.datetime.min
+    if pod.status and pod.status.container_statuses:
         for container_status in pod.status.container_statuses:
             terminated = container_status.state.terminated
             if terminated:
@@ -1264,20 +1267,38 @@ def _get_pod_termination_reason(pod: Any) -> str:
                 reason = terminated.reason
                 if exit_code == 0:
                     # skip exit 0 (non-failed) just for sanity
+                    logger.debug(f'{pod.metadata.name}/{container_status.name} '
+                                 'had exit code 0. Skipping.')
                     continue
                 if reason is None:
                     # just in-case reason is None, have default for debugging
                     reason = f'exit({exit_code})'
                 reasons.append(reason)
+                if terminated.finished_at > latest_timestamp:
+                    latest_timestamp = terminated.finished_at
+
     # TODO (kyuds): later, if needed, query `last_state` too.
 
+    if not reasons:
+        return ''
+
     # Normally we will have a single container per pod for skypilot
     # but doing this just in-case there are multiple containers.
-    return ' | '.join(reasons)
+    pod_reason = ' | '.join(reasons)
+
+    global_user_state.add_cluster_event(
+        cluster_name,
+        None,
+        f'[kubernetes pod {pod.metadata.name} terminated] {pod_reason}',
+        global_user_state.ClusterEventType.DEBUG,
+        transitioned_at=int(latest_timestamp.timestamp()),
+    )
+    return pod_reason
 
 
 def _get_pod_missing_reason(context: Optional[str], namespace: str,
                             cluster_name: str, pod_name: str) -> Optional[str]:
+    """Get events for missing pod and write to cluster events."""
     logger.debug(f'Analyzing events for pod {pod_name}')
     pod_field_selector = (
         f'involvedObject.kind=Pod,involvedObject.name={pod_name}')
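For orientation, here is a minimal, self-contained sketch of the reason-collection loop introduced above, assuming a V1Pod-shaped object from the official Kubernetes Python client. The helper name and the timezone-aware floor value are illustrative, not SkyPilot's actual code:

    import datetime
    from typing import Any, List

    # Kubernetes timestamps are timezone-aware, so compare against an
    # aware minimum rather than the naive datetime.min.
    _EPOCH = datetime.datetime.min.replace(tzinfo=datetime.timezone.utc)

    def collect_termination_reasons(pod: Any) -> str:
        reasons: List[str] = []
        latest = (pod.status.start_time if pod.status else None) or _EPOCH
        if pod.status and pod.status.container_statuses:
            for cs in pod.status.container_statuses:
                term = cs.state.terminated
                if term is None or term.exit_code == 0:
                    continue  # still running, waiting, or exited cleanly
                # Fall back to the exit code when no reason string is set.
                reasons.append(term.reason or f'exit({term.exit_code})')
                if term.finished_at and term.finished_at > latest:
                    latest = term.finished_at
        # One container per pod is the norm; join defensively anyway.
        return ' | '.join(reasons)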
@@ -1293,6 +1314,8 @@ def _get_pod_missing_reason(context: Optional[str], namespace: str,
     last_scheduled_node = None
     insert_new_pod_event = True
     new_event_inserted = False
+    inserted_pod_events = 0
+
     for event in pod_events:
         if event.reason == 'Scheduled':
             pattern = r'Successfully assigned (\S+) to (\S+)'
@@ -1313,10 +1336,18 @@ def _get_pod_missing_reason(context: Optional[str], namespace: str,
                     transitioned_at=int(
                         event.metadata.creation_timestamp.timestamp()),
                     expose_duplicate_error=True)
+                logger.debug(f'[pod {pod_name}] encountered new pod event: '
+                             f'{event.metadata.creation_timestamp} '
+                             f'{event.reason} {event.message}')
             except db_utils.UniqueConstraintViolationError:
                 insert_new_pod_event = False
             else:
                 new_event_inserted = True
+                inserted_pod_events += 1
+
+    logger.debug(f'[pod {pod_name}] processed {len(pod_events)} pod events and '
+                 f'inserted {inserted_pod_events} new pod events '
+                 'previously unseen')
 
     if last_scheduled_node is not None:
         node_field_selector = ('involvedObject.kind=Node,'
@@ -1331,6 +1362,7 @@ def _get_pod_missing_reason(context: Optional[str], namespace: str,
             # latest event appears first
             reverse=True)
         insert_new_node_event = True
+        inserted_node_events = 0
         for event in node_events:
             if insert_new_node_event:
                 # Try inserting the latest events first. If the event is a
@@ -1345,10 +1377,23 @@ def _get_pod_missing_reason(context: Optional[str], namespace: str,
                         transitioned_at=int(
                             event.metadata.creation_timestamp.timestamp()),
                         expose_duplicate_error=True)
+                    logger.debug(
+                        f'[pod {pod_name}] encountered new node event: '
+                        f'{event.metadata.creation_timestamp} '
+                        f'{event.reason} {event.message}')
                 except db_utils.UniqueConstraintViolationError:
                     insert_new_node_event = False
                 else:
                     new_event_inserted = True
+                    inserted_node_events += 1
+
+        logger.debug(f'[pod {pod_name}: node {last_scheduled_node}] '
+                     f'processed {len(node_events)} node events and '
+                     f'inserted {inserted_node_events} new node events '
+                     'previously unseen')
+    else:
+        logger.debug(f'[pod {pod_name}] could not determine the node '
+                     'the pod was scheduled to')
 
     if not new_event_inserted:
         # If new event is not inserted, there is no useful information to
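The insertion loops above lean on a database uniqueness violation to avoid re-recording history: events are tried newest-first, and the first duplicate implies everything older is already stored. A self-contained sketch of that idea using sqlite3 (the table schema is illustrative, not SkyPilot's actual cluster_events schema):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE events (name TEXT, ts INTEGER, message TEXT, '
                 'UNIQUE(name, ts, message))')

    def record_events(events):
        """events is a list of (name, ts, message), sorted newest-first."""
        inserted = 0
        for name, ts, message in events:
            try:
                with conn:  # commits on success, rolls back on error
                    conn.execute('INSERT INTO events VALUES (?, ?, ?)',
                                 (name, ts, message))
            except sqlite3.IntegrityError:
                # Newest-first ordering: the first duplicate means all
                # older events were recorded on a previous pass.
                break
            inserted += 1
        return inserted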
@@ -1390,13 +1435,15 @@ def query_instances(
     provider_config: Optional[Dict[str, Any]] = None,
     non_terminated_only: bool = True
 ) -> Dict[str, Tuple[Optional['status_lib.ClusterStatus'], Optional[str]]]:
+    # Mapping from pod phase to skypilot status. These are the only valid pod
+    # phases.
+    # https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase
     status_map = {
         'Pending': status_lib.ClusterStatus.INIT,
         'Running': status_lib.ClusterStatus.UP,
         'Failed': status_lib.ClusterStatus.INIT,
         'Unknown': None,
         'Succeeded': None,
-        'Terminating': None,
     }
 
     assert provider_config is not None
@@ -1440,18 +1487,15 @@ def query_instances(
     for pod in pods:
         phase = pod.status.phase
         pod_status = status_map[phase]
+        reason = None
+        if phase in ('Failed', 'Unknown'):
+            reason = _get_pod_termination_reason(pod, cluster_name)
+            logger.debug(f'Pod Status ({phase}) Reason(s): {reason}')
         if non_terminated_only and pod_status is None:
             logger.debug(f'Pod {pod.metadata.name} is terminated, but '
                          'query_instances is called with '
                          f'non_terminated_only=True. Phase: {phase}')
-            if phase == 'Failed':
-                reason_for_debug = _get_pod_termination_reason(pod)
-                logger.debug(f'Termination reason: {reason_for_debug}')
             continue
-        reason = None
-        if phase == 'Failed':
-            reason = _get_pod_termination_reason(pod)
-            logger.debug(f'Pod Status Reason(s): {reason}')
         pod_name = pod.metadata.name
         reason = f'{pod_name}: {reason}' if reason is not None else None
         cluster_status[pod_name] = (pod_status, reason)
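Dropping 'Terminating' here matches the Kubernetes pod lifecycle: Pending, Running, Succeeded, Failed, and Unknown are the only phases a pod reports; "Terminating" is a kubectl display state derived from a deletion timestamp, not a phase. A self-contained sketch of the same mapping with an explicit guard (stand-in names, not SkyPilot's status_lib):

    from enum import Enum
    from typing import Optional

    class ClusterStatus(Enum):  # stand-in for status_lib.ClusterStatus
        INIT = 'INIT'
        UP = 'UP'

    # Pending/Running/Succeeded/Failed/Unknown is the complete phase set.
    STATUS_MAP = {
        'Pending': ClusterStatus.INIT,
        'Running': ClusterStatus.UP,
        'Failed': ClusterStatus.INIT,
        'Unknown': None,
        'Succeeded': None,
    }

    def to_cluster_status(phase: str) -> Optional[ClusterStatus]:
        if phase not in STATUS_MAP:
            raise ValueError(f'Unexpected pod phase: {phase}')
        return STATUS_MAP[phase]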
sky/provision/kubernetes/utils.py CHANGED
@@ -1082,6 +1082,14 @@ class KarpenterAutoscaler(Autoscaler):
     can_query_backend: bool = False
 
 
+class CoreweaveAutoscaler(Autoscaler):
+    """CoreWeave autoscaler
+    """
+
+    label_formatter: Any = CoreWeaveLabelFormatter
+    can_query_backend: bool = False
+
+
 class GenericAutoscaler(Autoscaler):
     """Generic autoscaler
     """
@@ -1094,6 +1102,7 @@ class GenericAutoscaler(Autoscaler):
 AUTOSCALER_TYPE_TO_AUTOSCALER = {
     kubernetes_enums.KubernetesAutoscalerType.GKE: GKEAutoscaler,
     kubernetes_enums.KubernetesAutoscalerType.KARPENTER: KarpenterAutoscaler,
+    kubernetes_enums.KubernetesAutoscalerType.COREWEAVE: CoreweaveAutoscaler,
     kubernetes_enums.KubernetesAutoscalerType.GENERIC: GenericAutoscaler,
 }
 
@@ -2782,7 +2791,7 @@ def combine_pod_config_fields(
                                       kubernetes_config)
 
     # Write the updated YAML back to the file
-    common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
+    yaml_utils.dump_yaml(cluster_yaml_path, yaml_obj)
 
 
 def combine_metadata_fields(cluster_yaml_path: str,
@@ -2834,7 +2843,7 @@ def combine_metadata_fields(cluster_yaml_path: str,
     config_utils.merge_k8s_configs(destination, custom_metadata)
 
     # Write the updated YAML back to the file
-    common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
+    yaml_utils.dump_yaml(cluster_yaml_path, yaml_obj)
 
 
 def merge_custom_metadata(
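A recurring change across this release moves YAML helpers from common_utils into a dedicated yaml_utils module (note sky/utils/common_utils.py +0 -72 and sky/utils/yaml_utils.py +77 -10 in the file list). The diff does not show yaml_utils itself; a plausible minimal equivalent of the three helpers these hunks call would be:

    import yaml

    def read_yaml(path: str):
        with open(path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    def dump_yaml(path: str, obj) -> None:
        with open(path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(obj, f)

    def dump_yaml_str(obj) -> str:
        return yaml.safe_dump(obj)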
sky/provision/nebius/utils.py CHANGED
@@ -14,6 +14,8 @@ logger = sky_logging.init_logger(__name__)
 
 POLL_INTERVAL = 5
 
+_MAX_OPERATIONS_TO_FETCH = 1000
+
 
 def retry(func):
     """Decorator to retry a function."""
@@ -321,11 +323,43 @@ def launch(cluster_name_on_cloud: str,
                     parent_id=project_id,
                     name=instance_name,
                 )))
+        instance_id = instance.metadata.id
         if instance.status.state.name == 'STARTING':
-            instance_id = instance.metadata.id
             break
+
+        # All Instances initially have state=STOPPED and reconciling=True,
+        # so we need to wait until reconciling is False.
+        if instance.status.state.name == 'STOPPED' and \
+            not instance.status.reconciling:
+            next_token = ''
+            total_operations = 0
+            while True:
+                operations_response = nebius.sync_call(
+                    service.list_operations_by_parent(
+                        nebius.compute().ListOperationsByParentRequest(
+                            parent_id=project_id,
+                            page_size=100,
+                            page_token=next_token,
+                        )))
+                total_operations += len(operations_response.operations)
+                for operation in operations_response.operations:
+                    # Find the most recent operation for the instance.
+                    if operation.resource_id == instance_id:
+                        error_msg = operation.description
+                        if operation.status:
+                            error_msg += f' {operation.status.message}'
+                        raise RuntimeError(error_msg)
+                # If we've fetched too many operations, or there are no more
+                # operations to fetch, just raise a generic error.
+                if total_operations > _MAX_OPERATIONS_TO_FETCH or \
+                    not operations_response.next_page_token:
+                    raise RuntimeError(
+                        f'Instance {instance_name} failed to start.')
+                next_token = operations_response.next_page_token
         time.sleep(POLL_INTERVAL)
-        logger.debug(f'Waiting for instance {instance_name} start running.')
+        logger.debug(f'Waiting for instance {instance_name} to start running. '
+                     f'State: {instance.status.state.name}, '
+                     f'Reconciling: {instance.status.reconciling}')
         retry_count += 1
 
     if retry_count == nebius.MAX_RETRIES_TO_INSTANCE_READY:
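The new failure-diagnosis path is a capped pagination scan: walk the project's operations 100 at a time, surface the first operation that touched the instance as the error message, and give up with a generic error once 1000 operations have been fetched or the pages run out. A self-contained sketch of the shape of that loop, with a fake client standing in for the Nebius SDK (method and field names are illustrative):

    MAX_ITEMS = 1000

    def find_failure_description(client, parent_id, resource_id):
        """Return the description of the newest operation on resource_id,
        or None if not found within the fetch cap."""
        token, fetched = '', 0
        while True:
            page = client.list_operations(parent_id=parent_id,
                                          page_size=100,
                                          page_token=token)
            fetched += len(page.operations)
            for op in page.operations:
                # Pages are assumed newest-first, so the first hit wins.
                if op.resource_id == resource_id:
                    return op.description
            if fetched > MAX_ITEMS or not page.next_page_token:
                return None  # cap reached or no more pages
            token = page.next_page_token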
sky/schemas/db/global_user_state/007_cluster_event_request_id.py ADDED
@@ -0,0 +1,34 @@
+"""Add request_id to cluster_events.
+
+Revision ID: 007
+Revises: 006
+Create Date: 2025-08-28
+
+"""
+# pylint: disable=invalid-name
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+from sky.utils.db import db_utils
+
+# revision identifiers, used by Alembic.
+revision: str = '007'
+down_revision: Union[str, Sequence[str], None] = '006'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade():
+    """Add request_id column to cluster_events."""
+    with op.get_context().autocommit_block():
+        db_utils.add_column_to_table_alembic('cluster_events',
+                                             'request_id',
+                                             sa.Text(),
+                                             server_default=None)
+
+
+def downgrade():
+    """No-op for backward compatibility."""
+    pass
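Without SkyPilot's db_utils wrapper (which presumably adds safety around re-running the migration on a table that already has the column), the equivalent raw Alembic operation would look roughly like this:

    import sqlalchemy as sa
    from alembic import op

    def upgrade():
        # Plain Alembic column add; fails if the column already exists.
        op.add_column(
            'cluster_events',
            sa.Column('request_id', sa.Text(), server_default=None))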
sky/serve/client/impl.py CHANGED
@@ -224,10 +224,11 @@ def tail_logs(service_name: str,
                                            stream=True)
     request_id: server_common.RequestId[None] = server_common.get_request_id(
         response)
-    return sdk.stream_response(request_id=request_id,
-                               response=response,
-                               output_stream=output_stream,
-                               resumable=True)
+    sdk.stream_response(request_id=request_id,
+                        response=response,
+                        output_stream=output_stream,
+                        resumable=True,
+                        get_result=follow)
 
 
 def sync_down_logs(service_name: str,
sky/serve/replica_managers.py CHANGED
@@ -37,6 +37,7 @@ from sky.utils import env_options
 from sky.utils import resources_utils
 from sky.utils import status_lib
 from sky.utils import ux_utils
+from sky.utils import yaml_utils
 
 if typing.TYPE_CHECKING:
     from sky.serve import service_spec
@@ -79,7 +80,7 @@ def launch_cluster(replica_id: int,
                  f'{cluster_name} with resources override: '
                  f'{resources_override}')
     try:
-        config = common_utils.read_yaml(
+        config = yaml_utils.read_yaml(
             os.path.expanduser(service_task_yaml_path))
         task = task_lib.Task.from_yaml_config(config)
         if resources_override is not None:
@@ -1397,7 +1398,7 @@ class SkyPilotReplicaManager(ReplicaManager):
             # the latest version. This can significantly improve the speed
             # for updating an existing service with only config changes to the
             # service specs, e.g. scale down the service.
-            new_config = common_utils.read_yaml(
+            new_config = yaml_utils.read_yaml(
                 os.path.expanduser(service_task_yaml_path))
             # Always create new replicas and scale down old ones when file_mounts
             # are not empty.
@@ -1414,7 +1415,7 @@ class SkyPilotReplicaManager(ReplicaManager):
             old_service_task_yaml_path = (
                 serve_utils.generate_task_yaml_file_name(
                     self._service_name, info.version))
-            old_config = common_utils.read_yaml(
+            old_config = yaml_utils.read_yaml(
                 os.path.expanduser(old_service_task_yaml_path))
             for key in ['service', 'pool', '_user_specified_yaml']:
                 old_config.pop(key, None)
sky/serve/serve_utils.py CHANGED
@@ -699,7 +699,7 @@ def _get_service_status(
     if record['pool']:
         latest_yaml_path = generate_task_yaml_file_name(service_name,
                                                         record['version'])
-        raw_yaml_config = common_utils.read_yaml(latest_yaml_path)
+        raw_yaml_config = yaml_utils.read_yaml(latest_yaml_path)
         original_config = raw_yaml_config.get('_user_specified_yaml')
         if original_config is None:
             # Fall back to old display format.
@@ -711,7 +711,7 @@ def _get_service_status(
             original_config['pool'] = svc  # Add pool to root config
         else:
             original_config = yaml_utils.safe_load(original_config)
-        record['pool_yaml'] = common_utils.dump_yaml_str(original_config)
+        record['pool_yaml'] = yaml_utils.dump_yaml_str(original_config)
 
     record['target_num_replicas'] = 0
     try:
sky/serve/server/impl.py CHANGED
@@ -34,6 +34,7 @@ from sky.utils import dag_utils
 from sky.utils import rich_utils
 from sky.utils import subprocess_utils
 from sky.utils import ux_utils
+from sky.utils import yaml_utils
 
 logger = sky_logging.init_logger(__name__)
 
@@ -179,7 +180,7 @@ def up(
     controller = controller_utils.get_controller_for_pool(pool)
     controller_name = controller.value.cluster_name
     task_config = task.to_yaml_config()
-    common_utils.dump_yaml(service_file.name, task_config)
+    yaml_utils.dump_yaml(service_file.name, task_config)
     remote_tmp_task_yaml_path = (
         serve_utils.generate_remote_tmp_task_yaml_file_name(service_name))
     remote_config_yaml_path = (
@@ -531,7 +532,7 @@ def update(
             prefix=f'{service_name}-v{current_version}',
             mode='w') as service_file:
         task_config = task.to_yaml_config()
-        common_utils.dump_yaml(service_file.name, task_config)
+        yaml_utils.dump_yaml(service_file.name, task_config)
     remote_task_yaml_path = serve_utils.generate_task_yaml_file_name(
         service_name, current_version, expand_user=False)
 
sky/serve/server/server.py CHANGED
@@ -107,7 +107,8 @@ async def tail_logs(
         request_cluster_name=common.SKY_SERVE_CONTROLLER_NAME,
     )
 
-    request_task = api_requests.get_request(request.state.request_id)
+    request_task = await api_requests.get_request_async(request.state.request_id
+                                                       )
 
     return stream_utils.stream_response(
         request_id=request_task.request_id,
sky/server/auth/oauth2_proxy.py CHANGED
@@ -4,6 +4,7 @@ import asyncio
 import hashlib
 import http
 import os
+import traceback
 from typing import Optional
 import urllib
 
@@ -109,8 +110,8 @@ class OAuth2ProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
         try:
             return await self._authenticate(request, call_next, session)
         except (aiohttp.ClientError, asyncio.TimeoutError) as e:
-            logger.error(f'Error communicating with OAuth2 proxy: {e}')
-            # Fail open or closed based on your security requirements
+            logger.error(f'Error communicating with OAuth2 proxy: {e}'
+                         f'{traceback.format_exc()}')
             return fastapi.responses.JSONResponse(
                 status_code=http.HTTPStatus.BAD_GATEWAY,
                 content={'detail': 'oauth2-proxy service unavailable'})
@@ -120,10 +121,15 @@ class OAuth2ProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
         forwarded_headers = dict(request.headers)
         auth_url = f'{self.proxy_base}/oauth2/auth'
         forwarded_headers['X-Forwarded-Uri'] = str(request.url).rstrip('/')
-        logger.debug(f'authenticate request: {request.url.path}')
+        # Remove content-length and content-type headers and drop request body
+        # to reduce the auth overhead.
+        forwarded_headers.pop('content-length', None)
+        forwarded_headers.pop('content-type', None)
+        logger.debug(f'authenticate request: {auth_url}, '
+                     f'headers: {forwarded_headers}')
 
         async with session.request(
-                method=request.method,
+                method='GET',
                 url=auth_url,
                 headers=forwarded_headers,
                 cookies=request.cookies,
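This change turns every auth check into a bodyless GET subrequest against oauth2-proxy's /oauth2/auth endpoint, so large uploads no longer have to be replayed just to authenticate. A rough aiohttp sketch of the pattern (error handling elided; oauth2-proxy documents a 202 response for authorized sessions):

    import aiohttp

    async def check_auth(session: aiohttp.ClientSession, proxy_base: str,
                         original_headers: dict, original_url: str,
                         cookies: dict) -> bool:
        headers = dict(original_headers)
        headers['X-Forwarded-Uri'] = original_url.rstrip('/')
        # Drop body-describing headers since no body is forwarded.
        headers.pop('content-length', None)
        headers.pop('content-type', None)
        async with session.request(method='GET',
                                   url=f'{proxy_base}/oauth2/auth',
                                   headers=headers,
                                   cookies=cookies,
                                   allow_redirects=False) as resp:
            return resp.status == 202  # 202 Accepted means authorized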
sky/server/common.py CHANGED
@@ -41,6 +41,7 @@ from sky.utils import annotations
 from sky.utils import common_utils
 from sky.utils import rich_utils
 from sky.utils import ux_utils
+from sky.utils import yaml_utils
 
 if typing.TYPE_CHECKING:
     import aiohttp
@@ -816,7 +817,7 @@ def process_mounts_in_task_on_api_server(task: str, env_vars: Dict[str, str],
         return str(client_file_mounts_dir /
                    file_mounts_mapping[original_path].lstrip('/'))
 
-    task_configs = common_utils.read_yaml_all(str(client_task_path))
+    task_configs = yaml_utils.read_yaml_all(str(client_task_path))
     for task_config in task_configs:
         if task_config is None:
             continue
@@ -869,7 +870,7 @@ def process_mounts_in_task_on_api_server(task: str, env_vars: Dict[str, str],
     # We can switch to using string, but this is to make it easier to debug, by
     # persisting the translated task yaml file.
     translated_client_task_path = client_dir / f'{task_id}_translated.yaml'
-    common_utils.dump_yaml(str(translated_client_task_path), task_configs)
+    yaml_utils.dump_yaml(str(translated_client_task_path), task_configs)
 
     dag = dag_utils.load_chain_dag_from_yaml(str(translated_client_task_path))
     return dag
@@ -910,8 +911,7 @@ def reload_for_new_request(client_entrypoint: Optional[str],
 
     # Clear cache should be called before reload_logger and usage reset,
     # otherwise, the latest env var will not be used.
-    for func in annotations.FUNCTIONS_NEED_RELOAD_CACHE:
-        func.cache_clear()
+    annotations.clear_request_level_cache()
 
     # We need to reset usage message, so that the message is up-to-date with the
     # latest information in the context, e.g. client entrypoint and run id.
sky/server/daemons.py CHANGED
@@ -7,8 +7,10 @@ from typing import Callable
 from sky import sky_logging
 from sky import skypilot_config
 from sky.server import constants as server_constants
+from sky.utils import annotations
 from sky.utils import common
 from sky.utils import env_options
+from sky.utils import timeline
 from sky.utils import ux_utils
 
 logger = sky_logging.init_logger(__name__)
@@ -67,6 +69,10 @@ class InternalRequestDaemon:
                 sky_logging.reload_logger()
                 level = self.refresh_log_level()
                 self.event_fn()
+                # Clear request level cache after each run to avoid
+                # using too much memory.
+                annotations.clear_request_level_cache()
+                timeline.save_timeline()
             except Exception:  # pylint: disable=broad-except
                 # It is OK to fail to run the event, as the event is not
                 # critical, but we should log the error.
@@ -191,23 +197,28 @@ INTERNAL_REQUEST_DAEMONS = [
     # set to updated status automatically, without showing users the hint of
     # cluster being stopped or down when `sky status -r` is called.
     InternalRequestDaemon(id='skypilot-status-refresh-daemon',
-                          name='status',
+                          name='status-refresh',
                           event_fn=refresh_cluster_status_event,
                           default_log_level='DEBUG'),
     # Volume status refresh daemon to update the volume status periodically.
     InternalRequestDaemon(id='skypilot-volume-status-refresh-daemon',
-                          name='volume',
+                          name='volume-refresh',
                           event_fn=refresh_volume_status_event),
     InternalRequestDaemon(id='managed-job-status-refresh-daemon',
-                          name='managed-job-status',
+                          name='managed-job-status-refresh',
                           event_fn=managed_job_status_refresh_event,
                           should_skip=should_skip_managed_job_status_refresh),
     InternalRequestDaemon(id='sky-serve-status-refresh-daemon',
-                          name='sky-serve-status',
+                          name='sky-serve-status-refresh',
                           event_fn=sky_serve_status_refresh_event,
                           should_skip=should_skip_sky_serve_status_refresh),
     InternalRequestDaemon(id='pool-status-refresh-daemon',
-                          name='pool-status',
+                          name='pool-status-refresh',
                           event_fn=pool_status_refresh_event,
                           should_skip=should_skip_pool_status_refresh),
 ]
+
+
+def is_daemon_request_id(request_id: str) -> bool:
+    """Returns whether a specific request_id is an internal daemon."""
+    return any([d.id == request_id for d in INTERNAL_REQUEST_DAEMONS])
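A self-contained equivalent of the new helper; a generator expression would also work and avoids materializing the intermediate list that any([...]) builds:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Daemon:  # stand-in for InternalRequestDaemon
        id: str

    DAEMONS: List[Daemon] = [
        Daemon('skypilot-status-refresh-daemon'),
        Daemon('skypilot-volume-status-refresh-daemon'),
    ]

    def is_daemon_request_id(request_id: str) -> bool:
        # Simple membership test over the fixed daemon registry.
        return any(d.id == request_id for d in DAEMONS)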
sky/server/requests/executor.py CHANGED
@@ -55,6 +55,7 @@ from sky.utils import context_utils
 from sky.utils import subprocess_utils
 from sky.utils import tempstore
 from sky.utils import timeline
+from sky.utils import yaml_utils
 from sky.workspaces import core as workspaces_core
 
 if typing.TYPE_CHECKING:
@@ -382,12 +383,13 @@ def _request_execution_wrapper(request_id: str,
     # config, as there can be some logs during override that needs to be
     # captured in the log file.
     try:
-        with override_request_env_and_config(request_body, request_id), \
+        with sky_logging.add_debug_log_handler(request_id), \
+                override_request_env_and_config(request_body, request_id), \
                 tempstore.tempdir():
             if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
                 config = skypilot_config.to_dict()
                 logger.debug(f'request config: \n'
-                             f'{common_utils.dump_yaml_str(dict(config))}')
+                             f'{yaml_utils.dump_yaml_str(dict(config))}')
             return_value = func(**request_body.to_kwargs())
             f.flush()
     except KeyboardInterrupt:
@@ -451,7 +453,7 @@ async def execute_request_coroutine(request: api_requests.Request):
                                         **request_body.to_kwargs())
 
     async def poll_task(request_id: str) -> bool:
-        request = api_requests.get_request(request_id)
+        request = await api_requests.get_request_async(request_id)
         if request is None:
             raise RuntimeError('Request not found')
 
sky/server/requests/payloads.py CHANGED
@@ -71,7 +71,9 @@ EXTERNAL_LOCAL_ENV_VARS = [
 def request_body_env_vars() -> dict:
     env_vars = {}
     for env_var in os.environ:
-        if env_var.startswith(constants.SKYPILOT_ENV_VAR_PREFIX):
+        if (env_var.startswith(constants.SKYPILOT_ENV_VAR_PREFIX) and
+                not env_var.startswith(
+                    constants.SKYPILOT_SERVER_ENV_VAR_PREFIX)):
             env_vars[env_var] = os.environ[env_var]
         if common.is_api_server_local() and env_var in EXTERNAL_LOCAL_ENV_VARS:
             env_vars[env_var] = os.environ[env_var]
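Standalone sketch of the filtering rule above: forward SKYPILOT_* variables to the request body while holding back server-internal ones. The literal prefix strings are assumptions inferred from the constant names, not verified against the source:

    import os

    CLIENT_PREFIX = 'SKYPILOT_'          # assumed SKYPILOT_ENV_VAR_PREFIX
    SERVER_PREFIX = 'SKYPILOT_SERVER_'   # assumed SKYPILOT_SERVER_ENV_VAR_PREFIX

    def request_env_vars() -> dict:
        # Server-internal variables match both prefixes, so the negative
        # check must come after the positive one.
        return {
            k: v for k, v in os.environ.items()
            if k.startswith(CLIENT_PREFIX) and not k.startswith(SERVER_PREFIX)
        }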
sky/server/requests/preconditions.py CHANGED
@@ -98,7 +98,7 @@ class Precondition(abc.ABC):
             return False
 
         # Check if the request has been cancelled
-        request = api_requests.get_request(self.request_id)
+        request = await api_requests.get_request_async(self.request_id)
         if request is None:
             logger.error(f'Request {self.request_id} not found')
             return False
@@ -112,7 +112,8 @@ class Precondition(abc.ABC):
             return True
         if status_msg is not None and status_msg != last_status_msg:
             # Update the status message if it has changed.
-            with api_requests.update_request(self.request_id) as req:
+            async with api_requests.update_request_async(
+                    self.request_id) as req:
                 assert req is not None, self.request_id
                 req.status_msg = status_msg
                 last_status_msg = status_msg
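The async variants adopted in this hunk (and in server.py and executor.py above) keep the same read-modify-write shape as their sync counterparts. A minimal sketch of what an update_request_async-style helper plausibly looks like; the DB calls here are hypothetical, not SkyPilot's actual API:

    import contextlib

    @contextlib.asynccontextmanager
    async def update_record(db, record_id):
        # Load the row, hand it to the caller for mutation, then persist
        # whatever state it holds on exit.
        record = await db.get(record_id)  # hypothetical async DB helper
        try:
            yield record
        finally:
            if record is not None:
                await db.save(record)     # hypothetical async DB helper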