skypilot-nightly 1.0.0.dev20251022__py3-none-any.whl → 1.0.0.dev20251024__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. Click here for more details.

Files changed (70) hide show
  1. sky/__init__.py +2 -2
  2. sky/client/cli/command.py +118 -30
  3. sky/client/cli/table_utils.py +14 -8
  4. sky/dashboard/out/404.html +1 -1
  5. sky/dashboard/out/_next/static/{IgACOQPupLbX9z-RYVEDx → KxZarRnMeQID-gZkBfhzv}/_buildManifest.js +1 -1
  6. sky/dashboard/out/_next/static/chunks/1141-145e542070a6b615.js +11 -0
  7. sky/dashboard/out/_next/static/chunks/1871-5f68bb683940d23f.js +6 -0
  8. sky/dashboard/out/_next/static/chunks/2755.210d4b7c92c6efb2.js +26 -0
  9. sky/dashboard/out/_next/static/chunks/3015-2dcace420c8939f4.js +1 -0
  10. sky/dashboard/out/_next/static/chunks/3294.bdfea6492fdf845a.js +1 -0
  11. sky/dashboard/out/_next/static/chunks/{3785.483a3dda2d52f26e.js → 3785.e4aaef3b6c460dfb.js} +1 -1
  12. sky/dashboard/out/_next/static/chunks/{4725.10f7a9a5d3ea8208.js → 4725.a830b5c9e7867c92.js} +1 -1
  13. sky/dashboard/out/_next/static/chunks/8632-396a81107ba407c3.js +1 -0
  14. sky/dashboard/out/_next/static/chunks/9360.07d78b8552bc9d17.js +31 -0
  15. sky/dashboard/out/_next/static/chunks/pages/{_app-ce361c6959bc2001.js → _app-7d2f4467a685de9b.js} +1 -1
  16. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/{[job]-602eeead010ec1d6.js → [job]-1d90fbf8f0ffd194.js} +1 -1
  17. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-5edf81da6677352c.js +1 -0
  18. sky/dashboard/out/_next/static/chunks/pages/{clusters-57221ec2e4e01076.js → clusters-d887733ca514d22e.js} +1 -1
  19. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-44ce535a0a0ad4ec.js → [context]-79b68556c00d284e.js} +1 -1
  20. sky/dashboard/out/_next/static/chunks/pages/{infra-872e6a00165534f4.js → infra-e1c3789e69102cdb.js} +1 -1
  21. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-10f12af2b199ff68.js +16 -0
  22. sky/dashboard/out/_next/static/chunks/pages/{users-3a543725492fb896.js → users-0befb1254ac5c965.js} +1 -1
  23. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-9ad108cd67d16d96.js → [name]-ecc38434a21e8a1e.js} +1 -1
  24. sky/dashboard/out/_next/static/chunks/{webpack-919e3c01ab6b2633.js → webpack-31acb2000f9b372c.js} +1 -1
  25. sky/dashboard/out/_next/static/css/4c052b4444e52a58.css +3 -0
  26. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  27. sky/dashboard/out/clusters/[cluster].html +1 -1
  28. sky/dashboard/out/clusters.html +1 -1
  29. sky/dashboard/out/config.html +1 -1
  30. sky/dashboard/out/index.html +1 -1
  31. sky/dashboard/out/infra/[context].html +1 -1
  32. sky/dashboard/out/infra.html +1 -1
  33. sky/dashboard/out/jobs/[job].html +1 -1
  34. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  35. sky/dashboard/out/jobs.html +1 -1
  36. sky/dashboard/out/users.html +1 -1
  37. sky/dashboard/out/volumes.html +1 -1
  38. sky/dashboard/out/workspace/new.html +1 -1
  39. sky/dashboard/out/workspaces/[name].html +1 -1
  40. sky/dashboard/out/workspaces.html +1 -1
  41. sky/jobs/client/sdk.py +28 -9
  42. sky/jobs/client/sdk_async.py +9 -3
  43. sky/jobs/server/core.py +3 -1
  44. sky/jobs/utils.py +33 -22
  45. sky/server/auth/oauth2_proxy.py +2 -5
  46. sky/server/requests/requests.py +39 -6
  47. sky/server/requests/serializers/decoders.py +23 -10
  48. sky/server/requests/serializers/encoders.py +4 -3
  49. sky/server/rest.py +35 -1
  50. sky/skylet/log_lib.py +8 -1
  51. sky/skylet/subprocess_daemon.py +103 -29
  52. sky/utils/db/db_utils.py +21 -0
  53. sky/utils/subprocess_utils.py +13 -1
  54. {skypilot_nightly-1.0.0.dev20251022.dist-info → skypilot_nightly-1.0.0.dev20251024.dist-info}/METADATA +33 -33
  55. {skypilot_nightly-1.0.0.dev20251022.dist-info → skypilot_nightly-1.0.0.dev20251024.dist-info}/RECORD +60 -60
  56. sky/dashboard/out/_next/static/chunks/1141-ec6f902ffb865853.js +0 -11
  57. sky/dashboard/out/_next/static/chunks/1871-df9f87fcb7f24292.js +0 -6
  58. sky/dashboard/out/_next/static/chunks/2755.9b1e69c921b5a870.js +0 -26
  59. sky/dashboard/out/_next/static/chunks/3015-d014dc5b9412fade.js +0 -1
  60. sky/dashboard/out/_next/static/chunks/3294.998db87cd52a1238.js +0 -1
  61. sky/dashboard/out/_next/static/chunks/6135-4b4d5e824b7f9d3c.js +0 -1
  62. sky/dashboard/out/_next/static/chunks/9360.14326e329484b57e.js +0 -31
  63. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-18b334dedbd9f6f2.js +0 -1
  64. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-8677af16befde039.js +0 -16
  65. sky/dashboard/out/_next/static/css/4614e06482d7309e.css +0 -3
  66. /sky/dashboard/out/_next/static/{IgACOQPupLbX9z-RYVEDx → KxZarRnMeQID-gZkBfhzv}/_ssgManifest.js +0 -0
  67. {skypilot_nightly-1.0.0.dev20251022.dist-info → skypilot_nightly-1.0.0.dev20251024.dist-info}/WHEEL +0 -0
  68. {skypilot_nightly-1.0.0.dev20251022.dist-info → skypilot_nightly-1.0.0.dev20251024.dist-info}/entry_points.txt +0 -0
  69. {skypilot_nightly-1.0.0.dev20251022.dist-info → skypilot_nightly-1.0.0.dev20251024.dist-info}/licenses/LICENSE +0 -0
  70. {skypilot_nightly-1.0.0.dev20251022.dist-info → skypilot_nightly-1.0.0.dev20251024.dist-info}/top_level.txt +0 -0
sky/jobs/client/sdk.py CHANGED
@@ -130,8 +130,11 @@ def queue(
130
130
  refresh: bool,
131
131
  skip_finished: bool = False,
132
132
  all_users: bool = False,
133
- job_ids: Optional[List[int]] = None
134
- ) -> server_common.RequestId[List[responses.ManagedJobRecord]]:
133
+ job_ids: Optional[List[int]] = None,
134
+ limit: Optional[int] = None,
135
+ fields: Optional[List[str]] = None,
136
+ ) -> server_common.RequestId[Union[List[responses.ManagedJobRecord], Tuple[
137
+ List[responses.ManagedJobRecord], int, Dict[str, int], int]]]:
135
138
  """Gets statuses of managed jobs.
136
139
 
137
140
  Please refer to sky.cli.job_queue for documentation.
@@ -141,6 +144,8 @@ def queue(
141
144
  skip_finished: Whether to skip finished jobs.
142
145
  all_users: Whether to show all users' jobs.
143
146
  job_ids: IDs of the managed jobs to show.
147
+ limit: Number of jobs to show.
148
+ fields: Fields to get for the managed jobs.
144
149
 
145
150
  Returns:
146
151
  The request ID of the queue request.
@@ -173,15 +178,29 @@ def queue(
173
178
  does not exist.
174
179
  RuntimeError: if failed to get the managed jobs with ssh.
175
180
  """
176
- body = payloads.JobsQueueBody(
177
- refresh=refresh,
178
- skip_finished=skip_finished,
179
- all_users=all_users,
180
- job_ids=job_ids,
181
- )
181
+ remote_api_version = versions.get_remote_api_version()
182
+ if remote_api_version and remote_api_version >= 18:
183
+ body = payloads.JobsQueueV2Body(
184
+ refresh=refresh,
185
+ skip_finished=skip_finished,
186
+ all_users=all_users,
187
+ job_ids=job_ids,
188
+ limit=limit,
189
+ fields=fields,
190
+ )
191
+ path = '/jobs/queue/v2'
192
+ else:
193
+ body = payloads.JobsQueueBody(
194
+ refresh=refresh,
195
+ skip_finished=skip_finished,
196
+ all_users=all_users,
197
+ job_ids=job_ids,
198
+ )
199
+ path = '/jobs/queue'
200
+
182
201
  response = server_common.make_authenticated_request(
183
202
  'POST',
184
- '/jobs/queue',
203
+ path,
185
204
  json=json.loads(body.model_dump_json()),
186
205
  timeout=(5, None))
187
206
  return server_common.get_request_id(response=response)
@@ -1,12 +1,13 @@
1
1
  """Async SDK functions for managed jobs."""
2
2
  import typing
3
- from typing import Any, Dict, List, Optional, Tuple, Union
3
+ from typing import Dict, List, Optional, Tuple, Union
4
4
 
5
5
  from sky import backends
6
6
  from sky import sky_logging
7
7
  from sky.adaptors import common as adaptors_common
8
8
  from sky.client import sdk_async
9
9
  from sky.jobs.client import sdk
10
+ from sky.schemas.api import responses
10
11
  from sky.skylet import constants
11
12
  from sky.usage import usage_lib
12
13
  from sky.utils import common_utils
@@ -50,12 +51,17 @@ async def queue(
50
51
  refresh: bool,
51
52
  skip_finished: bool = False,
52
53
  all_users: bool = False,
54
+ job_ids: Optional[List[int]] = None,
55
+ limit: Optional[int] = None,
56
+ fields: Optional[List[str]] = None,
53
57
  stream_logs: Optional[
54
58
  sdk_async.StreamConfig] = sdk_async.DEFAULT_STREAM_CONFIG
55
- ) -> List[Dict[str, Any]]:
59
+ ) -> Union[List[responses.ManagedJobRecord], Tuple[
60
+ List[responses.ManagedJobRecord], int, Dict[str, int], int]]:
56
61
  """Async version of queue() that gets statuses of managed jobs."""
57
62
  request_id = await context_utils.to_thread(sdk.queue, refresh,
58
- skip_finished, all_users)
63
+ skip_finished, all_users,
64
+ job_ids, limit, fields)
59
65
  if stream_logs is not None:
60
66
  return await sdk_async._stream_and_get(request_id, stream_logs) # pylint: disable=protected-access
61
67
  else:
sky/jobs/server/core.py CHANGED
@@ -337,6 +337,7 @@ def launch(
337
337
  def _submit_one(
338
338
  consolidation_mode_job_id: Optional[int] = None,
339
339
  job_rank: Optional[int] = None,
340
+ num_jobs: Optional[int] = None,
340
341
  ) -> Tuple[Optional[int], Optional[backends.ResourceHandle]]:
341
342
  rank_suffix = '' if job_rank is None else f'-{job_rank}'
342
343
  remote_original_user_yaml_path = (
@@ -359,6 +360,7 @@ def launch(
359
360
  for task_ in dag.tasks:
360
361
  if job_rank is not None:
361
362
  task_.update_envs({'SKYPILOT_JOB_RANK': str(job_rank)})
363
+ task_.update_envs({'SKYPILOT_NUM_JOBS': str(num_jobs)})
362
364
 
363
365
  dag_utils.dump_chain_dag_to_yaml(dag, f.name)
364
366
 
@@ -475,7 +477,7 @@ def launch(
475
477
  for job_rank in range(num_jobs):
476
478
  job_id = (consolidation_mode_job_ids[job_rank]
477
479
  if consolidation_mode_job_ids is not None else None)
478
- jid, handle = _submit_one(job_id, job_rank)
480
+ jid, handle = _submit_one(job_id, job_rank, num_jobs=num_jobs)
479
481
  assert jid is not None, (job_id, handle)
480
482
  ids.append(jid)
481
483
  all_handle = handle
sky/jobs/utils.py CHANGED
@@ -1710,29 +1710,37 @@ def _get_job_status_from_tasks(
1710
1710
 
1711
1711
 
1712
1712
  @typing.overload
1713
- def format_job_table(tasks: List[Dict[str, Any]],
1714
- show_all: bool,
1715
- show_user: bool,
1716
- return_rows: Literal[False] = False,
1717
- max_jobs: Optional[int] = None) -> str:
1713
+ def format_job_table(
1714
+ tasks: List[Dict[str, Any]],
1715
+ show_all: bool,
1716
+ show_user: bool,
1717
+ return_rows: Literal[False] = False,
1718
+ max_jobs: Optional[int] = None,
1719
+ job_status_counts: Optional[Dict[str, int]] = None,
1720
+ ) -> str:
1718
1721
  ...
1719
1722
 
1720
1723
 
1721
1724
  @typing.overload
1722
- def format_job_table(tasks: List[Dict[str, Any]],
1723
- show_all: bool,
1724
- show_user: bool,
1725
- return_rows: Literal[True],
1726
- max_jobs: Optional[int] = None) -> List[List[str]]:
1725
+ def format_job_table(
1726
+ tasks: List[Dict[str, Any]],
1727
+ show_all: bool,
1728
+ show_user: bool,
1729
+ return_rows: Literal[True],
1730
+ max_jobs: Optional[int] = None,
1731
+ job_status_counts: Optional[Dict[str, int]] = None,
1732
+ ) -> List[List[str]]:
1727
1733
  ...
1728
1734
 
1729
1735
 
1730
1736
  def format_job_table(
1731
- tasks: List[Dict[str, Any]],
1732
- show_all: bool,
1733
- show_user: bool,
1734
- return_rows: bool = False,
1735
- max_jobs: Optional[int] = None) -> Union[str, List[List[str]]]:
1737
+ tasks: List[Dict[str, Any]],
1738
+ show_all: bool,
1739
+ show_user: bool,
1740
+ return_rows: bool = False,
1741
+ max_jobs: Optional[int] = None,
1742
+ job_status_counts: Optional[Dict[str, int]] = None,
1743
+ ) -> Union[str, List[List[str]]]:
1736
1744
  """Returns managed jobs as a formatted string.
1737
1745
 
1738
1746
  Args:
@@ -1741,6 +1749,7 @@ def format_job_table(
1741
1749
  max_jobs: The maximum number of jobs to show in the table.
1742
1750
  return_rows: If True, return the rows as a list of strings instead of
1743
1751
  all rows concatenated into a single string.
1752
+ job_status_counts: The counts of each job status.
1744
1753
 
1745
1754
  Returns: A formatted string of managed jobs, if not `return_rows`; otherwise
1746
1755
  a list of "rows" (each of which is a list of str).
@@ -1762,12 +1771,8 @@ def format_job_table(
1762
1771
  # by the task_id.
1763
1772
  jobs[get_hash(task)].append(task)
1764
1773
 
1765
- status_counts: Dict[str, int] = collections.defaultdict(int)
1766
1774
  workspaces = set()
1767
1775
  for job_tasks in jobs.values():
1768
- managed_job_status = _get_job_status_from_tasks(job_tasks)[0]
1769
- if not managed_job_status.is_terminal():
1770
- status_counts[managed_job_status.value] += 1
1771
1776
  workspaces.add(job_tasks[0].get('workspace',
1772
1777
  constants.SKYPILOT_DEFAULT_WORKSPACE))
1773
1778
 
@@ -1810,9 +1815,15 @@ def format_job_table(
1810
1815
  job_table = log_utils.create_table(columns)
1811
1816
 
1812
1817
  status_counts: Dict[str, int] = collections.defaultdict(int)
1813
- for task in tasks:
1814
- if not task['status'].is_terminal():
1815
- status_counts[task['status'].value] += 1
1818
+ if job_status_counts:
1819
+ for status_value, count in job_status_counts.items():
1820
+ status = managed_job_state.ManagedJobStatus(status_value)
1821
+ if not status.is_terminal():
1822
+ status_counts[status_value] = count
1823
+ else:
1824
+ for task in tasks:
1825
+ if not task['status'].is_terminal():
1826
+ status_counts[task['status'].value] += 1
1816
1827
 
1817
1828
  all_tasks = tasks
1818
1829
  if max_jobs is not None:
@@ -126,13 +126,10 @@ class OAuth2ProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
126
126
 
127
127
  async def _authenticate(self, request: fastapi.Request, call_next,
128
128
  session: aiohttp.ClientSession):
129
- forwarded_headers = dict(request.headers)
129
+ forwarded_headers = {}
130
130
  auth_url = f'{self.proxy_base}/oauth2/auth'
131
131
  forwarded_headers['X-Forwarded-Uri'] = str(request.url).rstrip('/')
132
- # Remove content-length and content-type headers and drop request body
133
- # to reduce the auth overhead.
134
- forwarded_headers.pop('content-length', None)
135
- forwarded_headers.pop('content-type', None)
132
+ forwarded_headers['Host'] = request.url.hostname
136
133
  logger.debug(f'authenticate request: {auth_url}, '
137
134
  f'headers: {forwarded_headers}')
138
135
 
@@ -578,6 +578,26 @@ def reset_db_and_logs():
578
578
  f'{server_common.API_SERVER_CLIENT_DIR.expanduser()}')
579
579
  shutil.rmtree(server_common.API_SERVER_CLIENT_DIR.expanduser(),
580
580
  ignore_errors=True)
581
+ with _init_db_lock:
582
+ _init_db_within_lock()
583
+ assert _DB is not None
584
+ with _DB.conn:
585
+ cursor = _DB.conn.cursor()
586
+ cursor.execute('SELECT sqlite_version()')
587
+ row = cursor.fetchone()
588
+ if row is None:
589
+ raise RuntimeError('Failed to get SQLite version')
590
+ version_str = row[0]
591
+ version_parts = version_str.split('.')
592
+ assert len(version_parts) >= 2, \
593
+ f'Invalid version string: {version_str}'
594
+ major, minor = int(version_parts[0]), int(version_parts[1])
595
+ # SQLite 3.35.0+ supports RETURNING statements.
596
+ # 3.35.0 was released in March 2021.
597
+ if not ((major > 3) or (major == 3 and minor >= 35)):
598
+ raise RuntimeError(
599
+ f'SQLite version {version_str} is not supported. '
600
+ 'Please upgrade to SQLite 3.35.0 or later.')
581
601
 
582
602
 
583
603
  def request_lock_path(request_id: str) -> str:
@@ -733,12 +753,25 @@ async def get_request_status_async(
733
753
  @metrics_lib.time_me_async
734
754
  @asyncio_utils.shield
735
755
  async def create_if_not_exists_async(request: Request) -> bool:
736
- """Async version of create_if_not_exists."""
737
- async with filelock.AsyncFileLock(request_lock_path(request.request_id)):
738
- if await _get_request_no_lock_async(request.request_id) is not None:
739
- return False
740
- await _add_or_update_request_no_lock_async(request)
741
- return True
756
+ """Create a request if it does not exist, otherwise do nothing.
757
+
758
+ Returns:
759
+ True if a new request is created, False if the request already exists.
760
+ """
761
+ assert _DB is not None
762
+ request_columns = ', '.join(REQUEST_COLUMNS)
763
+ values_str = ', '.join(['?'] * len(REQUEST_COLUMNS))
764
+ sql_statement = (
765
+ f'INSERT INTO {REQUEST_TABLE} '
766
+ f'({request_columns}) VALUES '
767
+ f'({values_str}) ON CONFLICT(request_id) DO NOTHING RETURNING ROWID')
768
+ request_row = request.to_row()
769
+ # Execute the SQL statement without getting the request lock.
770
+ # The request lock is used to prevent racing with cancellation codepath,
771
+ # but a request cannot be cancelled before it is created.
772
+ row = await _DB.execute_get_returning_value_async(sql_statement,
773
+ request_row)
774
+ return True if row else False
742
775
 
743
776
 
744
777
  @dataclasses.dataclass
@@ -2,7 +2,7 @@
2
2
  import base64
3
3
  import pickle
4
4
  import typing
5
- from typing import Any, Dict, List, Optional, Tuple
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
6
 
7
7
  from sky import jobs as managed_jobs
8
8
  from sky import models
@@ -116,22 +116,35 @@ def decode_jobs_queue(return_value: List[dict],) -> List[Dict[str, Any]]:
116
116
 
117
117
 
118
118
  @register_decoders('jobs.queue_v2')
119
- def decode_jobs_queue_v2(return_value) -> List[responses.ManagedJobRecord]:
119
+ def decode_jobs_queue_v2(
120
+ return_value
121
+ ) -> Union[Tuple[List[responses.ManagedJobRecord], int, Dict[str, int], int],
122
+ List[responses.ManagedJobRecord]]:
120
123
  """Decode jobs queue response.
121
124
 
122
- Supports legacy list, or a dict {jobs, total}.
123
- - Returns list[job]
125
+ Supports legacy list, or a dict {jobs, total, total_no_filter,
126
+ status_counts}.
127
+
128
+ - Returns either list[job] or tuple(list[job], total, status_counts,
129
+ total_no_filter)
124
130
  """
125
- # Case 1: dict shape {jobs, total}
126
- if isinstance(return_value, dict) and 'jobs' in return_value:
131
+ # Case 1: dict shape {jobs, total, total_no_filter, status_counts}
132
+ if isinstance(return_value, dict):
127
133
  jobs = return_value.get('jobs', [])
134
+ total = return_value.get('total', len(jobs))
135
+ total_no_filter = return_value.get('total_no_filter', total)
136
+ status_counts = return_value.get('status_counts', {})
137
+ for job in jobs:
138
+ job['status'] = managed_jobs.ManagedJobStatus(job['status'])
139
+ jobs = [responses.ManagedJobRecord(**job) for job in jobs]
140
+ return jobs, total, status_counts, total_no_filter
128
141
  else:
129
142
  # Case 2: legacy list
130
143
  jobs = return_value
131
- for job in jobs:
132
- job['status'] = managed_jobs.ManagedJobStatus(job['status'])
133
- jobs = [responses.ManagedJobRecord(**job) for job in jobs]
134
- return jobs
144
+ for job in jobs:
145
+ job['status'] = managed_jobs.ManagedJobStatus(job['status'])
146
+ jobs = [responses.ManagedJobRecord(**job) for job in jobs]
147
+ return jobs
135
148
 
136
149
 
137
150
  def _decode_serve_status(
@@ -148,12 +148,13 @@ def encode_jobs_queue_v2(
148
148
  else:
149
149
  jobs = jobs_or_tuple
150
150
  total = None
151
- for job in jobs:
151
+ jobs_dict = [job.model_dump(by_alias=True) for job in jobs]
152
+ for job in jobs_dict:
152
153
  job['status'] = job['status'].value
153
154
  if total is None:
154
- return [job.model_dump(by_alias=True) for job in jobs]
155
+ return jobs_dict
155
156
  return {
156
- 'jobs': [job.model_dump(by_alias=True) for job in jobs],
157
+ 'jobs': jobs_dict,
157
158
  'total': total,
158
159
  'total_no_filter': total_no_filter,
159
160
  'status_counts': status_counts
sky/server/rest.py CHANGED
@@ -256,6 +256,40 @@ def handle_server_unavailable(response: 'requests.Response') -> None:
256
256
  raise exceptions.ServerTemporarilyUnavailableError(error_msg)
257
257
 
258
258
 
259
+ async def handle_server_unavailable_async(
260
+ response: 'aiohttp.ClientResponse') -> None:
261
+ """Async version: Handle 503 (Service Unavailable) error
262
+
263
+ The client get 503 error in the following cases:
264
+ 1. The reverse proxy cannot find any ready backend endpoints to serve the
265
+ request, e.g. when there is a rolling-update.
266
+ 2. The skypilot API server has temporary resource issue, e.g. when the
267
+ concurrency of the handling process is exhausted.
268
+
269
+ We expect the caller (CLI or SDK) to retry on these cases and show a clear wait
270
+ message to the user to let user decide whether keep waiting or abort the
271
+ request.
272
+ """
273
+ if response.status != 503:
274
+ return
275
+
276
+ error_msg = ''
277
+ try:
278
+ response_data = await response.json()
279
+ if 'detail' in response_data:
280
+ error_msg = response_data['detail']
281
+ except Exception: # pylint: disable=broad-except
282
+ try:
283
+ text = await response.text()
284
+ if text:
285
+ error_msg = text
286
+ except Exception: # pylint: disable=broad-except
287
+ pass
288
+
289
+ with ux_utils.print_exception_no_traceback():
290
+ raise exceptions.ServerTemporarilyUnavailableError(error_msg)
291
+
292
+
259
293
  @_retry_on_server_unavailable()
260
294
  def request(method, url, **kwargs) -> 'requests.Response':
261
295
  """Send a request to the API server, retry on server temporarily
@@ -332,7 +366,7 @@ async def request_without_retry_async(session: 'aiohttp.ClientSession',
332
366
  response = await session.request(method, url, **kwargs)
333
367
 
334
368
  # Handle server unavailability (503 status) - same as sync version
335
- handle_server_unavailable(response)
369
+ await handle_server_unavailable_async(response)
336
370
 
337
371
  # Set remote API version and version from headers - same as sync version
338
372
  remote_api_version = response.headers.get(constants.API_VERSION_HEADER)
sky/skylet/log_lib.py CHANGED
@@ -220,7 +220,14 @@ def run_with_log(
220
220
  stdin=stdin,
221
221
  **kwargs) as proc:
222
222
  try:
223
- subprocess_utils.kill_process_daemon(proc.pid)
223
+ if ctx is not None:
224
+ # When runs in coroutine, use kill_pg if available to avoid
225
+ # the overhead of refreshing the process tree in the daemon.
226
+ subprocess_utils.kill_process_daemon(proc.pid, use_kill_pg=True)
227
+ else:
228
+ # For backward compatibility, do not specify use_kill_pg by
229
+ # default.
230
+ subprocess_utils.kill_process_daemon(proc.pid)
224
231
  stdout = ''
225
232
  stderr = ''
226
233
  stdout_stream_handler = None
@@ -4,11 +4,16 @@ processes of proc_pid.
4
4
  """
5
5
  import argparse
6
6
  import os
7
+ import signal
7
8
  import sys
8
9
  import time
10
+ from typing import List, Optional
9
11
 
10
12
  import psutil
11
13
 
14
+ # Environment variable to enable kill_pg in subprocess daemon.
15
+ USE_KILL_PG_ENV_VAR = 'SKYPILOT_SUBPROCESS_DAEMON_KILL_PG'
16
+
12
17
 
13
18
  def daemonize():
14
19
  """Detaches the process from its parent process with double-forking.
@@ -38,8 +43,74 @@ def daemonize():
38
43
  # This process is now fully detached from the original parent and terminal
39
44
 
40
45
 
41
- if __name__ == '__main__':
42
- daemonize()
46
+ def get_pgid_if_leader(pid) -> Optional[int]:
47
+ """Get the process group ID of the target process if it is the leader."""
48
+ try:
49
+ pgid = os.getpgid(pid)
50
+ # Only use process group if the target process is the leader. This is
51
+ # to avoid killing the entire process group while the target process is
52
+ # just a subprocess in the group.
53
+ if pgid == pid:
54
+ print(f'Process group {pgid} is the leader.')
55
+ return pgid
56
+ return None
57
+ except Exception: # pylint: disable=broad-except
58
+ # Process group is only available in UNIX.
59
+ return None
60
+
61
+
62
+ def kill_process_group(pgid: int) -> bool:
63
+ """Kill the target process group."""
64
+ try:
65
+ print(f'Terminating process group {pgid}...')
66
+ os.killpg(pgid, signal.SIGTERM)
67
+ except Exception: # pylint: disable=broad-except
68
+ return False
69
+
70
+ # Wait 30s for the process group to exit gracefully.
71
+ time.sleep(30)
72
+
73
+ try:
74
+ print(f'Force killing process group {pgid}...')
75
+ os.killpg(pgid, signal.SIGKILL)
76
+ except Exception: # pylint: disable=broad-except
77
+ pass
78
+
79
+ return True
80
+
81
+
82
+ def kill_process_tree(process: psutil.Process,
83
+ children: List[psutil.Process]) -> bool:
84
+ """Kill the process tree of the target process."""
85
+ if process is not None:
86
+ # Kill the target process first to avoid having more children, or fail
87
+ # the process due to the children being defunct.
88
+ children = [process] + children
89
+
90
+ if not children:
91
+ sys.exit()
92
+
93
+ for child in children:
94
+ try:
95
+ child.terminate()
96
+ except psutil.NoSuchProcess:
97
+ continue
98
+
99
+ # Wait 30s for the processes to exit gracefully.
100
+ time.sleep(30)
101
+
102
+ # SIGKILL if they're still running.
103
+ for child in children:
104
+ try:
105
+ child.kill()
106
+ except psutil.NoSuchProcess:
107
+ continue
108
+
109
+ return True
110
+
111
+
112
+ def main():
113
+ # daemonize()
43
114
  parser = argparse.ArgumentParser()
44
115
  parser.add_argument('--parent-pid', type=int, required=True)
45
116
  parser.add_argument('--proc-pid', type=int, required=True)
@@ -72,37 +143,40 @@ if __name__ == '__main__':
72
143
  except (psutil.NoSuchProcess, ValueError):
73
144
  pass
74
145
 
146
+ pgid: Optional[int] = None
147
+ if os.environ.get(USE_KILL_PG_ENV_VAR) == '1':
148
+ # Use kill_pg on UNIX system if allowed to reduce the resource usage.
149
+ # Note that both implementations might leave subprocesses uncancelled:
150
+ # - kill_process_tree(default): a subprocess is able to detach itself
151
+ # from the process tree using the same technique as daemonize(). Also,
152
+ # since we refresh the process tree per second, if the subprocess is
153
+ # launched between the [last_poll, parent_die] interval, the
154
+ # subprocess will not be captured and will not be killed.
155
+ # - kill_process_group: kill_pg will kill all the processes in the group
156
+ # but if a subprocess calls setpgid(0, 0) to detach itself from the
157
+ # process group (usually to daemonize itself), the subprocess will
158
+ # not be killed.
159
+ pgid = get_pgid_if_leader(process.pid)
160
+
75
161
  if process is not None and parent_process is not None:
76
162
  # Wait for either parent or target process to exit
77
163
  while process.is_running() and parent_process.is_running():
78
- try:
79
- tmp_children = process.children(recursive=True)
80
- if tmp_children:
81
- children = tmp_children
82
- except psutil.NoSuchProcess:
83
- pass
164
+ if pgid is None:
165
+ # Refresh process tree for cleanup if process group is not
166
+ # available.
167
+ try:
168
+ tmp_children = process.children(recursive=True)
169
+ if tmp_children:
170
+ children = tmp_children
171
+ except psutil.NoSuchProcess:
172
+ pass
84
173
  time.sleep(1)
85
174
 
86
- if process is not None:
87
- # Kill the target process first to avoid having more children, or fail
88
- # the process due to the children being defunct.
89
- children = [process] + children
175
+ if pgid is not None:
176
+ kill_process_group(pgid)
177
+ else:
178
+ kill_process_tree(process, children)
90
179
 
91
- if not children:
92
- sys.exit()
93
180
 
94
- for child in children:
95
- try:
96
- child.terminate()
97
- except psutil.NoSuchProcess:
98
- continue
99
-
100
- # Wait 30s for the processes to exit gracefully.
101
- time.sleep(30)
102
-
103
- # SIGKILL if they're still running.
104
- for child in children:
105
- try:
106
- child.kill()
107
- except psutil.NoSuchProcess:
108
- continue
181
+ if __name__ == '__main__':
182
+ main()
sky/utils/db/db_utils.py CHANGED
@@ -358,6 +358,27 @@ class SQLiteConn(threading.local):
358
358
  conn = await self._get_async_conn()
359
359
  return await conn.execute_fetchall(sql, parameters)
360
360
 
361
+ async def execute_get_returning_value_async(
362
+ self,
363
+ sql: str,
364
+ parameters: Optional[Iterable[Any]] = None
365
+ ) -> Optional[sqlite3.Row]:
366
+ conn = await self._get_async_conn()
367
+
368
+ if parameters is None:
369
+ parameters = []
370
+
371
+ def exec_and_get_returning_value(sql: str,
372
+ parameters: Optional[Iterable[Any]]):
373
+ # pylint: disable=protected-access
374
+ row = conn._conn.execute(sql, parameters).fetchone()
375
+ conn._conn.commit()
376
+ return row
377
+
378
+ # pylint: disable=protected-access
379
+ return await conn._execute(exec_and_get_returning_value, sql,
380
+ parameters)
381
+
361
382
  async def close(self):
362
383
  if self._async_conn is not None:
363
384
  await self._async_conn.close()
@@ -19,6 +19,7 @@ from sky import exceptions
19
19
  from sky import sky_logging
20
20
  from sky.adaptors import common as adaptors_common
21
21
  from sky.skylet import log_lib
22
+ from sky.skylet import subprocess_daemon
22
23
  from sky.utils import common_utils
23
24
  from sky.utils import timeline
24
25
  from sky.utils import ux_utils
@@ -306,11 +307,17 @@ def run_with_retries(
306
307
  return returncode, stdout, stderr
307
308
 
308
309
 
309
- def kill_process_daemon(process_pid: int) -> None:
310
+ def kill_process_daemon(process_pid: int, use_kill_pg: bool = False) -> None:
310
311
  """Start a daemon as a safety net to kill the process.
311
312
 
312
313
  Args:
313
314
  process_pid: The PID of the process to kill.
315
+ use_kill_pg: Whether to use kill process group to kill the process. If
316
+ True, the process will use os.killpg() to kill the target process
317
+ group on UNIX system, which is more efficient than using the daemon
318
+ to refresh the process tree in the daemon. Note that both
319
+ implementations have corner cases where subprocesses might not be
320
+ killed. Refer to subprocess_daemon.py for more details.
314
321
  """
315
322
  # Get initial children list
316
323
  try:
@@ -337,6 +344,10 @@ def kill_process_daemon(process_pid: int) -> None:
337
344
  ','.join(map(str, initial_children)),
338
345
  ]
339
346
 
347
+ env = os.environ.copy()
348
+ if use_kill_pg:
349
+ env[subprocess_daemon.USE_KILL_PG_ENV_VAR] = '1'
350
+
340
351
  # We do not need to set `start_new_session=True` here, as the
341
352
  # daemon script will detach itself from the parent process with
342
353
  # fork to avoid being killed by parent process. See the reason we
@@ -348,6 +359,7 @@ def kill_process_daemon(process_pid: int) -> None:
348
359
  stderr=subprocess.DEVNULL,
349
360
  # Disable input
350
361
  stdin=subprocess.DEVNULL,
362
+ env=env,
351
363
  )
352
364
 
353
365