skypilot-nightly 1.0.0.dev20250903__py3-none-any.whl → 1.0.0.dev20250905__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sky/__init__.py +2 -2
  2. sky/backends/cloud_vm_ray_backend.py +18 -2
  3. sky/client/cli/command.py +18 -1
  4. sky/client/sdk.py +42 -17
  5. sky/clouds/nebius.py +4 -2
  6. sky/dashboard/out/404.html +1 -1
  7. sky/dashboard/out/_next/static/chunks/{1121-ec35954c8cbea535.js → 1121-408ed10b2f9fce17.js} +1 -1
  8. sky/dashboard/out/_next/static/chunks/{7205-88191679e7988c57.js → 1836-37fede578e2da5f8.js} +4 -9
  9. sky/dashboard/out/_next/static/chunks/3015-86cabed5d4669ad0.js +1 -0
  10. sky/dashboard/out/_next/static/chunks/3294.c80326aec9bfed40.js +6 -0
  11. sky/dashboard/out/_next/static/chunks/{3785.d5b86f6ebc88e6e6.js → 3785.4872a2f3aa489880.js} +1 -1
  12. sky/dashboard/out/_next/static/chunks/{4783.c485f48348349f47.js → 5339.3fda4a4010ff4e06.js} +4 -9
  13. sky/dashboard/out/_next/static/chunks/{9946.3b7b43c217ff70ec.js → 649.b9d7f7d10c1b8c53.js} +4 -9
  14. sky/dashboard/out/_next/static/chunks/6856-dca7962af4814e1b.js +1 -0
  15. sky/dashboard/out/_next/static/chunks/{8969-4a6f1a928fb6d370.js → 8969-0be3036bf86f8256.js} +1 -1
  16. sky/dashboard/out/_next/static/chunks/9025.c12318fb6a1a9093.js +6 -0
  17. sky/dashboard/out/_next/static/chunks/{9037-89a84fd7fa31362d.js → 9037-fa1737818d0a0969.js} +2 -2
  18. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-1cbba24bd1bd35f8.js +16 -0
  19. sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-a0527109c2fab467.js → [cluster]-0b4b35dc1dfe046c.js} +2 -7
  20. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-81351f95f3bec08e.js → [context]-6563820e094f68ca.js} +1 -1
  21. sky/dashboard/out/_next/static/chunks/pages/{infra-c320641c2bcbbea6.js → infra-aabba60d57826e0f.js} +1 -1
  22. sky/dashboard/out/_next/static/chunks/pages/jobs-1f70d9faa564804f.js +1 -0
  23. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-de06e613e20bc977.js → [name]-af76bb06dbb3954f.js} +1 -1
  24. sky/dashboard/out/_next/static/chunks/pages/{workspaces-be35b22e2046564c.js → workspaces-7598c33a746cdc91.js} +1 -1
  25. sky/dashboard/out/_next/static/chunks/webpack-4fe903277b57b523.js +1 -0
  26. sky/dashboard/out/_next/static/mS-4qZPSkRuA1u-g2wQhg/_buildManifest.js +1 -0
  27. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  28. sky/dashboard/out/clusters/[cluster].html +1 -1
  29. sky/dashboard/out/clusters.html +1 -1
  30. sky/dashboard/out/config.html +1 -1
  31. sky/dashboard/out/index.html +1 -1
  32. sky/dashboard/out/infra/[context].html +1 -1
  33. sky/dashboard/out/infra.html +1 -1
  34. sky/dashboard/out/jobs/[job].html +1 -1
  35. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  36. sky/dashboard/out/jobs.html +1 -1
  37. sky/dashboard/out/users.html +1 -1
  38. sky/dashboard/out/volumes.html +1 -1
  39. sky/dashboard/out/workspace/new.html +1 -1
  40. sky/dashboard/out/workspaces/[name].html +1 -1
  41. sky/dashboard/out/workspaces.html +1 -1
  42. sky/data/mounting_utils.py +29 -38
  43. sky/global_user_state.py +17 -5
  44. sky/jobs/state.py +1 -1
  45. sky/models.py +24 -1
  46. sky/provision/kubernetes/instance.py +10 -3
  47. sky/serve/serve_state.py +1 -1
  48. sky/server/config.py +31 -3
  49. sky/server/requests/executor.py +9 -3
  50. sky/server/requests/requests.py +24 -14
  51. sky/server/server.py +24 -21
  52. sky/server/uvicorn.py +9 -3
  53. sky/skylet/constants.py +1 -1
  54. sky/skypilot_config.py +21 -9
  55. sky/ssh_node_pools/server.py +5 -5
  56. sky/users/permission.py +6 -0
  57. sky/users/server.py +26 -17
  58. sky/utils/db/db_utils.py +61 -1
  59. sky/utils/db/migration_utils.py +0 -32
  60. {skypilot_nightly-1.0.0.dev20250903.dist-info → skypilot_nightly-1.0.0.dev20250905.dist-info}/METADATA +35 -35
  61. {skypilot_nightly-1.0.0.dev20250903.dist-info → skypilot_nightly-1.0.0.dev20250905.dist-info}/RECORD +66 -66
  62. sky/dashboard/out/_next/static/chunks/3015-8089ed1e0b7e37fd.js +0 -1
  63. sky/dashboard/out/_next/static/chunks/6856-049014c6d43d127b.js +0 -1
  64. sky/dashboard/out/_next/static/chunks/9025.a1bef12d672bb66d.js +0 -6
  65. sky/dashboard/out/_next/static/chunks/9984.7eb6cc51fb460cae.js +0 -6
  66. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-b77360a343d48902.js +0 -16
  67. sky/dashboard/out/_next/static/chunks/pages/jobs-7421e63ac35f8fce.js +0 -1
  68. sky/dashboard/out/_next/static/chunks/webpack-60556df644cd5d71.js +0 -1
  69. sky/dashboard/out/_next/static/yLz6EPhW_XXmnNs1I6dmS/_buildManifest.js +0 -1
  70. /sky/dashboard/out/_next/static/{yLz6EPhW_XXmnNs1I6dmS → mS-4qZPSkRuA1u-g2wQhg}/_ssgManifest.js +0 -0
  71. {skypilot_nightly-1.0.0.dev20250903.dist-info → skypilot_nightly-1.0.0.dev20250905.dist-info}/WHEEL +0 -0
  72. {skypilot_nightly-1.0.0.dev20250903.dist-info → skypilot_nightly-1.0.0.dev20250905.dist-info}/entry_points.txt +0 -0
  73. {skypilot_nightly-1.0.0.dev20250903.dist-info → skypilot_nightly-1.0.0.dev20250905.dist-info}/licenses/LICENSE +0 -0
  74. {skypilot_nightly-1.0.0.dev20250903.dist-info → skypilot_nightly-1.0.0.dev20250905.dist-info}/top_level.txt +0 -0
sky/global_user_state.py CHANGED
@@ -299,9 +299,7 @@ def create_table(engine: sqlalchemy.engine.Engine):
299
299
  # a session has already been created with _SQLALCHEMY_ENGINE = e1,
300
300
  # and then another thread overwrites _SQLALCHEMY_ENGINE = e2
301
301
  # which could result in e1 being garbage collected unexpectedly.
302
- def initialize_and_get_db(
303
- pg_pool_class: Optional[sqlalchemy.pool.Pool] = None
304
- ) -> sqlalchemy.engine.Engine:
302
+ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
305
303
  global _SQLALCHEMY_ENGINE
306
304
 
307
305
  if _SQLALCHEMY_ENGINE is not None:
@@ -310,8 +308,7 @@ def initialize_and_get_db(
310
308
  if _SQLALCHEMY_ENGINE is not None:
311
309
  return _SQLALCHEMY_ENGINE
312
310
  # get an engine to the db
313
- engine = migration_utils.get_engine('state',
314
- pg_pool_class=pg_pool_class)
311
+ engine = db_utils.get_engine('state')
315
312
 
316
313
  # run migrations if needed
317
314
  create_table(engine)
@@ -2315,3 +2312,18 @@ def set_system_config(config_key: str, config_value: str) -> None:
2315
2312
  })
2316
2313
  session.execute(upsert_stmnt)
2317
2314
  session.commit()
2315
+
2316
+
2317
+ @_init_db
2318
+ def get_max_db_connections() -> Optional[int]:
2319
+ """Get the maximum number of connections for the engine."""
2320
+ assert _SQLALCHEMY_ENGINE is not None
2321
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
2322
+ db_utils.SQLAlchemyDialect.SQLITE.value):
2323
+ return None
2324
+ with sqlalchemy.orm.Session(_SQLALCHEMY_ENGINE) as session:
2325
+ max_connections = session.execute(
2326
+ sqlalchemy.text('SHOW max_connections')).scalar()
2327
+ if max_connections is None:
2328
+ return None
2329
+ return int(max_connections)
sky/jobs/state.py CHANGED
@@ -157,7 +157,7 @@ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
157
157
  if _SQLALCHEMY_ENGINE is not None:
158
158
  return _SQLALCHEMY_ENGINE
159
159
  # get an engine to the db
160
- engine = migration_utils.get_engine('spot_jobs')
160
+ engine = db_utils.get_engine('spot_jobs')
161
161
 
162
162
  # run migrations if needed
163
163
  create_table(engine)
sky/models.py CHANGED
@@ -4,7 +4,7 @@ import collections
4
4
  import dataclasses
5
5
  import getpass
6
6
  import os
7
- from typing import Any, Dict, Optional
7
+ from typing import Any, ClassVar, Dict, Optional
8
8
 
9
9
  import pydantic
10
10
 
@@ -100,6 +100,11 @@ class KubernetesNodesInfo:
100
100
 
101
101
  class VolumeConfig(pydantic.BaseModel):
102
102
  """Configuration for creating a volume."""
103
+ # If any fields changed, increment the version. For backward compatibility,
104
+ # modify the __setstate__ method to handle the old version.
105
+ _VERSION: ClassVar[int] = 1
106
+
107
+ _version: int
103
108
  name: str
104
109
  type: str
105
110
  cloud: str
@@ -110,3 +115,21 @@ class VolumeConfig(pydantic.BaseModel):
110
115
  config: Dict[str, Any] = {}
111
116
  labels: Optional[Dict[str, str]] = None
112
117
  id_on_cloud: Optional[str] = None
118
+
119
+ def __getstate__(self) -> Dict[str, Any]:
120
+ state = super().__getstate__()
121
+ state['_version'] = self._VERSION
122
+ return state
123
+
124
+ def __setstate__(self, state: Dict[str, Any]) -> None:
125
+ """Set state from pickled state, for backward compatibility."""
126
+ super().__setstate__(state)
127
+ version = state.pop('_version', None)
128
+ if version is None:
129
+ version = -1
130
+
131
+ if version < 0:
132
+ state['id_on_cloud'] = None
133
+
134
+ state['_version'] = self._VERSION
135
+ self.__dict__.update(state)
@@ -1047,8 +1047,10 @@ def stop_instances(
1047
1047
  raise NotImplementedError()
1048
1048
 
1049
1049
 
1050
- def _delete_services(name_prefix: str, namespace: str,
1051
- context: Optional[str]) -> None:
1050
+ def _delete_services(name_prefix: str,
1051
+ namespace: str,
1052
+ context: Optional[str],
1053
+ skip_ssh_service: bool = False) -> None:
1052
1054
  """Delete services with the given name prefix.
1053
1055
 
1054
1056
  Args:
@@ -1057,7 +1059,9 @@ def _delete_services(name_prefix: str, namespace: str,
1057
1059
  context: Kubernetes context
1058
1060
  """
1059
1061
  # TODO(andy): We should use tag for the service filter.
1060
- for service_name in [name_prefix, f'{name_prefix}-ssh']:
1062
+ services = ([name_prefix, f'{name_prefix}-ssh']
1063
+ if not skip_ssh_service else [name_prefix])
1064
+ for service_name in services:
1061
1065
  # Since we are not saving this lambda, it's a false positive.
1062
1066
  # TODO(andyl): Wait for
1063
1067
  # https://github.com/pylint-dev/pylint/issues/5263.
@@ -1083,6 +1087,9 @@ def _terminate_node(namespace: str,
1083
1087
  # Delete services for the head pod
1084
1088
  # services are specified in sky/templates/kubernetes-ray.yml.j2
1085
1089
  _delete_services(pod_name, namespace, context)
1090
+ else:
1091
+ # No ssh service is created for worker pods
1092
+ _delete_services(pod_name, namespace, context, skip_ssh_service=True)
1086
1093
 
1087
1094
  # Note - delete pod after all other resources are deleted.
1088
1095
  # This is to ensure there are no leftover resources if this down is run
sky/serve/serve_state.py CHANGED
@@ -130,7 +130,7 @@ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
130
130
  if _SQLALCHEMY_ENGINE is not None:
131
131
  return _SQLALCHEMY_ENGINE
132
132
  # get an engine to the db
133
- engine = migration_utils.get_engine('serve/services')
133
+ engine = db_utils.get_engine('serve/services')
134
134
 
135
135
  # run migrations if needed
136
136
  create_table(engine)
sky/server/config.py CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  import dataclasses
4
4
  import enum
5
+ from typing import Optional
5
6
 
6
7
  from sky import sky_logging
7
8
  from sky.server import constants as server_constants
@@ -61,6 +62,7 @@ class QueueBackend(enum.Enum):
61
62
  class WorkerConfig:
62
63
  garanteed_parallelism: int
63
64
  burstable_parallelism: int
65
+ num_db_connections_per_worker: int
64
66
 
65
67
 
66
68
  @dataclasses.dataclass
@@ -68,10 +70,13 @@ class ServerConfig:
68
70
  num_server_workers: int
69
71
  long_worker_config: WorkerConfig
70
72
  short_worker_config: WorkerConfig
73
+ num_db_connections_per_worker: int
71
74
  queue_backend: QueueBackend
72
75
 
73
76
 
74
- def compute_server_config(deploy: bool) -> ServerConfig:
77
+ def compute_server_config(deploy: bool,
78
+ max_db_connections: Optional[int] = None
79
+ ) -> ServerConfig:
75
80
  """Compute the server config based on environment.
76
81
 
77
82
  We have different assumptions for the resources in different deployment
@@ -114,7 +119,17 @@ def compute_server_config(deploy: bool) -> ServerConfig:
114
119
  queue_backend = QueueBackend.MULTIPROCESSING
115
120
  burstable_parallel_for_long = 0
116
121
  burstable_parallel_for_short = 0
122
+ # if num_db_connections_per_worker is 0, server will use NullPool
123
+ # to conserve the number of concurrent db connections.
124
+ # This could lead to performance degradation.
125
+ num_db_connections_per_worker = 0
117
126
  num_server_workers = cpu_count
127
+
128
+ # +1 for the event loop running the main process
129
+ # and gc daemons in the '__main__' body of sky/server/server.py
130
+ max_parallel_all_workers = (max_parallel_for_long + max_parallel_for_short +
131
+ num_server_workers + 1)
132
+
118
133
  if not deploy:
119
134
  # For local mode, use local queue backend since we only run 1 uvicorn
120
135
  # worker in local mode and no multiprocessing is needed.
@@ -140,6 +155,16 @@ def compute_server_config(deploy: bool) -> ServerConfig:
140
155
  'SkyPilot API server will run in low resource mode because '
141
156
  'the available memory is less than '
142
157
  f'{server_constants.MIN_AVAIL_MEM_GB}GB.')
158
+ elif max_db_connections is not None:
159
+ if max_parallel_all_workers > max_db_connections:
160
+ logger.warning(
161
+ f'Max parallel all workers ({max_parallel_all_workers}) '
162
+ f'is greater than max db connections ({max_db_connections}). '
163
+ 'Increase the number of max db connections to '
164
+ f'at least {max_parallel_all_workers} for optimal performance.')
165
+ else:
166
+ num_db_connections_per_worker = 1
167
+
143
168
  logger.info(
144
169
  f'SkyPilot API server will start {num_server_workers} server processes '
145
170
  f'with {max_parallel_for_long} background workers for long requests '
@@ -150,10 +175,13 @@ def compute_server_config(deploy: bool) -> ServerConfig:
150
175
  queue_backend=queue_backend,
151
176
  long_worker_config=WorkerConfig(
152
177
  garanteed_parallelism=max_parallel_for_long,
153
- burstable_parallelism=burstable_parallel_for_long),
178
+ burstable_parallelism=burstable_parallel_for_long,
179
+ num_db_connections_per_worker=num_db_connections_per_worker),
154
180
  short_worker_config=WorkerConfig(
155
181
  garanteed_parallelism=max_parallel_for_short,
156
- burstable_parallelism=burstable_parallel_for_short),
182
+ burstable_parallelism=burstable_parallel_for_short,
183
+ num_db_connections_per_worker=num_db_connections_per_worker),
184
+ num_db_connections_per_worker=num_db_connections_per_worker,
157
185
  )
158
186
 
159
187
 
@@ -57,6 +57,7 @@ from sky.utils import subprocess_utils
57
57
  from sky.utils import tempstore
58
58
  from sky.utils import timeline
59
59
  from sky.utils import yaml_utils
60
+ from sky.utils.db import db_utils
60
61
  from sky.workspaces import core as workspaces_core
61
62
 
62
63
  if typing.TYPE_CHECKING:
@@ -152,6 +153,8 @@ class RequestWorker:
152
153
  self.schedule_type = schedule_type
153
154
  self.garanteed_parallelism = config.garanteed_parallelism
154
155
  self.burstable_parallelism = config.burstable_parallelism
156
+ self.num_db_connections_per_worker = (
157
+ config.num_db_connections_per_worker)
155
158
  self._thread: Optional[threading.Thread] = None
156
159
  self._cancel_event = threading.Event()
157
160
 
@@ -190,8 +193,9 @@ class RequestWorker:
190
193
  # multiple requests can share the same process pid, which may cause
191
194
  # issues with SkyPilot core functions if they rely on the exit of
192
195
  # the process, such as subprocess_daemon.py.
193
- fut = executor.submit_until_success(_request_execution_wrapper,
194
- request_id, ignore_return_value)
196
+ fut = executor.submit_until_success(
197
+ _request_execution_wrapper, request_id, ignore_return_value,
198
+ self.num_db_connections_per_worker)
195
199
  # Monitor the result of the request execution.
196
200
  threading.Thread(target=self.handle_task_result,
197
201
  args=(fut, request_element),
@@ -351,7 +355,8 @@ def _sigterm_handler(signum: int, frame: Optional['types.FrameType']) -> None:
351
355
 
352
356
 
353
357
  def _request_execution_wrapper(request_id: str,
354
- ignore_return_value: bool) -> None:
358
+ ignore_return_value: bool,
359
+ num_db_connections_per_worker: int = 0) -> None:
355
360
  """Wrapper for a request execution.
356
361
 
357
362
  It wraps the execution of a request to:
@@ -362,6 +367,7 @@ def _request_execution_wrapper(request_id: str,
362
367
  4. Handle the SIGTERM signal to abort the request gracefully.
363
368
  5. Maintain the lifecycle of the temp dir used by the request.
364
369
  """
370
+ db_utils.set_max_connections(num_db_connections_per_worker)
365
371
  # Handle the SIGTERM signal to abort the request processing gracefully.
366
372
  signal.signal(signal.SIGTERM, _sigterm_handler)
367
373
 
@@ -1,5 +1,6 @@
1
1
  """Utilities for REST API."""
2
2
  import asyncio
3
+ import atexit
3
4
  import contextlib
4
5
  import dataclasses
5
6
  import enum
@@ -16,6 +17,7 @@ import traceback
16
17
  from typing import (Any, AsyncContextManager, Callable, Dict, Generator, List,
17
18
  NamedTuple, Optional, Tuple)
18
19
 
20
+ import anyio
19
21
  import colorama
20
22
  import filelock
21
23
 
@@ -31,7 +33,6 @@ from sky.server.requests import payloads
31
33
  from sky.server.requests.serializers import decoders
32
34
  from sky.server.requests.serializers import encoders
33
35
  from sky.utils import common_utils
34
- from sky.utils import subprocess_utils
35
36
  from sky.utils import ux_utils
36
37
  from sky.utils.db import db_utils
37
38
 
@@ -783,17 +784,15 @@ def set_request_cancelled(request_id: str) -> None:
783
784
 
784
785
  @init_db
785
786
  @metrics_lib.time_me
786
- def _delete_requests(requests: List[Request]):
787
+ async def _delete_requests(requests: List[Request]):
787
788
  """Clean up requests by their IDs."""
788
789
  id_list_str = ','.join(repr(req.request_id) for req in requests)
789
790
  assert _DB is not None
790
- with _DB.conn:
791
- cursor = _DB.conn.cursor()
792
- cursor.execute(
793
- f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')
791
+ await _DB.execute_and_commit_async(
792
+ f'DELETE FROM {REQUEST_TABLE} WHERE request_id IN ({id_list_str})')
794
793
 
795
794
 
796
- def clean_finished_requests_with_retention(retention_seconds: int):
795
+ async def clean_finished_requests_with_retention(retention_seconds: int):
797
796
  """Clean up finished requests older than the retention period.
798
797
 
799
798
  This function removes old finished requests (SUCCEEDED, FAILED, CANCELLED)
@@ -803,17 +802,19 @@ def clean_finished_requests_with_retention(retention_seconds: int):
803
802
  retention_seconds: Requests older than this many seconds will be
804
803
  deleted.
805
804
  """
806
- reqs = get_request_tasks(
805
+ reqs = await get_request_tasks_async(
807
806
  req_filter=RequestTaskFilter(status=RequestStatus.finished_status(),
808
807
  finished_before=time.time() -
809
808
  retention_seconds))
810
809
 
811
- subprocess_utils.run_in_parallel(
812
- func=lambda req: req.log_path.unlink(missing_ok=True),
813
- args=reqs,
814
- num_threads=len(reqs))
810
+ futs = []
811
+ for req in reqs:
812
+ futs.append(
813
+ asyncio.create_task(
814
+ anyio.Path(req.log_path.absolute()).unlink(missing_ok=True)))
815
+ await asyncio.gather(*futs)
815
816
 
816
- _delete_requests(reqs)
817
+ await _delete_requests(reqs)
817
818
 
818
819
  # To avoid leakage of the log file, logs must be deleted before the
819
820
  # request task in the database.
@@ -838,7 +839,16 @@ async def requests_gc_daemon():
838
839
  logger.info('Requests GC daemon cancelled')
839
840
  break
840
841
  except Exception as e: # pylint: disable=broad-except
841
- logger.error(f'Error running requests GC daemon: {e}')
842
+ logger.error(f'Error running requests GC daemon: {e}'
843
+ f'traceback: {traceback.format_exc()}')
842
844
  # Run the daemon at most once every hour to avoid too frequent
843
845
  # cleanup.
844
846
  await asyncio.sleep(max(retention_seconds, 3600))
847
+
848
+
849
+ def _cleanup():
850
+ if _DB is not None:
851
+ asyncio.run(_DB.close())
852
+
853
+
854
+ atexit.register(_cleanup)
sky/server/server.py CHANGED
@@ -24,7 +24,6 @@ import aiofiles
24
24
  import anyio
25
25
  import fastapi
26
26
  from fastapi.middleware import cors
27
- from sqlalchemy import pool
28
27
  import starlette.middleware.base
29
28
  import uvloop
30
29
 
@@ -72,6 +71,7 @@ from sky.utils import dag_utils
72
71
  from sky.utils import perf_utils
73
72
  from sky.utils import status_lib
74
73
  from sky.utils import subprocess_utils
74
+ from sky.utils.db import db_utils
75
75
  from sky.volumes.server import server as volumes_rest
76
76
  from sky.workspaces import server as workspaces_rest
77
77
 
@@ -1322,18 +1322,17 @@ async def download(download_body: payloads.DownloadBody,
1322
1322
  detail=f'Error creating zip file: {str(e)}')
1323
1323
 
1324
1324
 
1325
+ # TODO(aylei): run it asynchronously after global_user_state support async op
1325
1326
  @app.post('/provision_logs')
1326
- async def provision_logs(cluster_body: payloads.ClusterNameBody,
1327
- follow: bool = True,
1328
- tail: int = 0) -> fastapi.responses.StreamingResponse:
1327
+ def provision_logs(cluster_body: payloads.ClusterNameBody,
1328
+ follow: bool = True,
1329
+ tail: int = 0) -> fastapi.responses.StreamingResponse:
1329
1330
  """Streams the provision.log for the latest launch request of a cluster."""
1330
1331
  # Prefer clusters table first, then cluster_history as fallback.
1331
- log_path_str = await context_utils.to_thread(
1332
- global_user_state.get_cluster_provision_log_path,
1332
+ log_path_str = global_user_state.get_cluster_provision_log_path(
1333
1333
  cluster_body.cluster_name)
1334
1334
  if not log_path_str:
1335
- log_path_str = await context_utils.to_thread(
1336
- global_user_state.get_cluster_history_provision_log_path,
1335
+ log_path_str = global_user_state.get_cluster_history_provision_log_path(
1337
1336
  cluster_body.cluster_name)
1338
1337
  if not log_path_str:
1339
1338
  raise fastapi.HTTPException(
@@ -1908,13 +1907,6 @@ if __name__ == '__main__':
1908
1907
 
1909
1908
  skyuvicorn.add_timestamp_prefix_for_server_logs()
1910
1909
 
1911
- # Initialize global user state db
1912
- global_user_state.initialize_and_get_db(pool.QueuePool)
1913
- # Initialize request db
1914
- requests_lib.reset_db_and_logs()
1915
- # Restore the server user hash
1916
- _init_or_restore_server_user_hash()
1917
-
1918
1910
  parser = argparse.ArgumentParser()
1919
1911
  parser.add_argument('--host', default='127.0.0.1')
1920
1912
  parser.add_argument('--port', default=46580, type=int)
@@ -1930,7 +1922,17 @@ if __name__ == '__main__':
1930
1922
  # that it is shown only when the API server is started.
1931
1923
  usage_lib.maybe_show_privacy_policy()
1932
1924
 
1933
- config = server_config.compute_server_config(cmd_args.deploy)
1925
+ # Initialize global user state db
1926
+ db_utils.set_max_connections(1)
1927
+ global_user_state.initialize_and_get_db()
1928
+ # Initialize request db
1929
+ requests_lib.reset_db_and_logs()
1930
+ # Restore the server user hash
1931
+ _init_or_restore_server_user_hash()
1932
+ max_db_connections = global_user_state.get_max_db_connections()
1933
+ config = server_config.compute_server_config(cmd_args.deploy,
1934
+ max_db_connections)
1935
+
1934
1936
  num_workers = config.num_server_workers
1935
1937
 
1936
1938
  queue_server: Optional[multiprocessing.Process] = None
@@ -1955,11 +1957,12 @@ if __name__ == '__main__':
1955
1957
  logger.info(f'Starting SkyPilot API server, workers={num_workers}')
1956
1958
  # We don't support reload for now, since it may cause leakage of request
1957
1959
  # workers or interrupt running requests.
1958
- config = uvicorn.Config('sky.server.server:app',
1959
- host=cmd_args.host,
1960
- port=cmd_args.port,
1961
- workers=num_workers)
1962
- skyuvicorn.run(config)
1960
+ uvicorn_config = uvicorn.Config('sky.server.server:app',
1961
+ host=cmd_args.host,
1962
+ port=cmd_args.port,
1963
+ workers=num_workers)
1964
+ skyuvicorn.run(uvicorn_config,
1965
+ max_db_connections=config.num_db_connections_per_worker)
1963
1966
  except Exception as exc: # pylint: disable=broad-except
1964
1967
  logger.error(f'Failed to start SkyPilot API server: '
1965
1968
  f'{common_utils.format_exception(exc, use_bracket=True)}')
sky/server/uvicorn.py CHANGED
@@ -26,6 +26,7 @@ from sky.utils import context_utils
26
26
  from sky.utils import env_options
27
27
  from sky.utils import perf_utils
28
28
  from sky.utils import subprocess_utils
29
+ from sky.utils.db import db_utils
29
30
 
30
31
  logger = sky_logging.init_logger(__name__)
31
32
 
@@ -88,9 +89,12 @@ class Server(uvicorn.Server):
88
89
  - Run the server process with contextually aware.
89
90
  """
90
91
 
91
- def __init__(self, config: uvicorn.Config):
92
+ def __init__(self,
93
+ config: uvicorn.Config,
94
+ max_db_connections: Optional[int] = None):
92
95
  super().__init__(config=config)
93
96
  self.exiting: bool = False
97
+ self.max_db_connections = max_db_connections
94
98
 
95
99
  def handle_exit(self, sig: int, frame: Union[FrameType, None]) -> None:
96
100
  """Handle exit signal.
@@ -196,6 +200,8 @@ class Server(uvicorn.Server):
196
200
 
197
201
  def run(self, *args, **kwargs):
198
202
  """Run the server process."""
203
+ if self.max_db_connections is not None:
204
+ db_utils.set_max_connections(self.max_db_connections)
199
205
  add_timestamp_prefix_for_server_logs()
200
206
  context_utils.hijack_sys_attrs()
201
207
  # Use default loop policy of uvicorn (use uvloop if available).
@@ -210,14 +216,14 @@ class Server(uvicorn.Server):
210
216
  asyncio.run(self.serve(*args, **kwargs))
211
217
 
212
218
 
213
- def run(config: uvicorn.Config):
219
+ def run(config: uvicorn.Config, max_db_connections: Optional[int] = None):
214
220
  """Run unvicorn server."""
215
221
  if config.reload:
216
222
  # Reload and multi-workers are mutually exclusive
217
223
  # in uvicorn. Since we do not use reload now, simply
218
224
  # guard by an exception.
219
225
  raise ValueError('Reload is not supported yet.')
220
- server = Server(config=config)
226
+ server = Server(config=config, max_db_connections=max_db_connections)
221
227
  try:
222
228
  if config.workers is not None and config.workers > 1:
223
229
  sock = config.bind_socket()
sky/skylet/constants.py CHANGED
@@ -362,7 +362,7 @@ SKY_SSH_USER_PLACEHOLDER = 'skypilot:ssh_user'
362
362
 
363
363
  RCLONE_CONFIG_DIR = '~/.config/rclone'
364
364
  RCLONE_CONFIG_PATH = f'{RCLONE_CONFIG_DIR}/rclone.conf'
365
- RCLONE_LOG_DIR = '~/.sky/rclone_log'
365
+ RCLONE_MOUNT_CACHED_LOG_DIR = '~/.sky/rclone_log'
366
366
  RCLONE_CACHE_DIR = '~/.cache/rclone'
367
367
  RCLONE_CACHE_REFRESH_INTERVAL = 10
368
368
 
sky/skypilot_config.py CHANGED
@@ -227,7 +227,7 @@ def _get_config_from_path(path: Optional[str]) -> config_utils.Config:
227
227
  return parse_and_validate_config_file(path)
228
228
 
229
229
 
230
- def _resolve_user_config_path() -> Optional[str]:
230
+ def resolve_user_config_path() -> Optional[str]:
231
231
  # find the user config file path, None if not resolved.
232
232
  user_config_path = _get_config_file_path(ENV_VAR_GLOBAL_CONFIG)
233
233
  if user_config_path:
@@ -252,7 +252,7 @@ def _resolve_user_config_path() -> Optional[str]:
252
252
 
253
253
  def get_user_config() -> config_utils.Config:
254
254
  """Returns the user config."""
255
- return _get_config_from_path(_resolve_user_config_path())
255
+ return _get_config_from_path(resolve_user_config_path())
256
256
 
257
257
 
258
258
  def _resolve_project_config_path() -> Optional[str]:
@@ -574,8 +574,13 @@ def _reload_config_as_server() -> None:
574
574
  'If db config is specified, no other config is allowed')
575
575
  logger.debug('retrieving config from database')
576
576
  with _DB_USE_LOCK:
577
- sqlalchemy_engine = sqlalchemy.create_engine(db_url,
578
- poolclass=NullPool)
577
+ dispose_engine = False
578
+ if db_utils.get_max_connections() == 0:
579
+ dispose_engine = True
580
+ sqlalchemy_engine = sqlalchemy.create_engine(db_url,
581
+ poolclass=NullPool)
582
+ else:
583
+ sqlalchemy_engine = db_utils.get_engine('config')
579
584
  db_utils.add_all_tables_to_db_sqlalchemy(Base.metadata,
580
585
  sqlalchemy_engine)
581
586
 
@@ -597,7 +602,8 @@ def _reload_config_as_server() -> None:
597
602
  server_config = overlay_skypilot_config(server_config,
598
603
  db_config)
599
604
  # Close the engine to avoid connection leaks
600
- sqlalchemy_engine.dispose()
605
+ if dispose_engine:
606
+ sqlalchemy_engine.dispose()
601
607
  if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
602
608
  logger.debug(f'server config: \n'
603
609
  f'{yaml_utils.dump_yaml_str(dict(server_config))}')
@@ -611,7 +617,7 @@ def _reload_config_as_client() -> None:
611
617
  _set_loaded_config_path(None)
612
618
 
613
619
  overrides: List[config_utils.Config] = []
614
- user_config_path = _resolve_user_config_path()
620
+ user_config_path = resolve_user_config_path()
615
621
  user_config = _get_config_from_path(user_config_path)
616
622
  if user_config:
617
623
  overrides.append(user_config)
@@ -867,8 +873,13 @@ def update_api_server_config_no_lock(config: config_utils.Config) -> None:
867
873
  raise ValueError('Cannot change db url while server is running')
868
874
  if existing_db_url:
869
875
  with _DB_USE_LOCK:
870
- sqlalchemy_engine = sqlalchemy.create_engine(existing_db_url,
871
- poolclass=NullPool)
876
+ dispose_engine = False
877
+ if db_utils.get_max_connections() == 0:
878
+ dispose_engine = True
879
+ sqlalchemy_engine = sqlalchemy.create_engine(
880
+ existing_db_url, poolclass=NullPool)
881
+ else:
882
+ sqlalchemy_engine = db_utils.get_engine('config')
872
883
  db_utils.add_all_tables_to_db_sqlalchemy(
873
884
  Base.metadata, sqlalchemy_engine)
874
885
 
@@ -897,7 +908,8 @@ def update_api_server_config_no_lock(config: config_utils.Config) -> None:
897
908
  _set_config_yaml_to_db(API_SERVER_CONFIG_KEY, config)
898
909
  db_updated = True
899
910
  # Close the engine to avoid connection leaks
900
- sqlalchemy_engine.dispose()
911
+ if dispose_engine:
912
+ sqlalchemy_engine.dispose()
901
913
 
902
914
  if not db_updated:
903
915
  # save to the local file (PVC in Kubernetes, local file otherwise)
@@ -15,7 +15,7 @@ router = fastapi.APIRouter()
15
15
 
16
16
 
17
17
  @router.get('')
18
- async def get_ssh_node_pools() -> Dict[str, Any]:
18
+ def get_ssh_node_pools() -> Dict[str, Any]:
19
19
  """Get all SSH Node Pool configurations."""
20
20
  try:
21
21
  return ssh_node_pools_core.get_all_pools()
@@ -27,7 +27,7 @@ async def get_ssh_node_pools() -> Dict[str, Any]:
27
27
 
28
28
 
29
29
  @router.post('')
30
- async def update_ssh_node_pools(pools_config: Dict[str, Any]) -> Dict[str, str]:
30
+ def update_ssh_node_pools(pools_config: Dict[str, Any]) -> Dict[str, str]:
31
31
  """Update SSH Node Pool configurations."""
32
32
  try:
33
33
  ssh_node_pools_core.update_pools(pools_config)
@@ -39,7 +39,7 @@ async def update_ssh_node_pools(pools_config: Dict[str, Any]) -> Dict[str, str]:
39
39
 
40
40
 
41
41
  @router.delete('/{pool_name}')
42
- async def delete_ssh_node_pool(pool_name: str) -> Dict[str, str]:
42
+ def delete_ssh_node_pool(pool_name: str) -> Dict[str, str]:
43
43
  """Delete a SSH Node Pool configuration."""
44
44
  try:
45
45
  if ssh_node_pools_core.delete_pool(pool_name):
@@ -83,7 +83,7 @@ async def upload_ssh_key(request: fastapi.Request) -> Dict[str, str]:
83
83
 
84
84
 
85
85
  @router.get('/keys')
86
- async def list_ssh_keys() -> List[str]:
86
+ def list_ssh_keys() -> List[str]:
87
87
  """List available SSH keys."""
88
88
  try:
89
89
  return ssh_node_pools_core.list_ssh_keys()
@@ -200,7 +200,7 @@ async def down_ssh_node_pool_general(
200
200
 
201
201
 
202
202
  @router.get('/{pool_name}/status')
203
- async def get_ssh_node_pool_status(pool_name: str) -> Dict[str, str]:
203
+ def get_ssh_node_pool_status(pool_name: str) -> Dict[str, str]:
204
204
  """Get the status of a specific SSH Node Pool."""
205
205
  try:
206
206
  # Call ssh_status to check the context
sky/users/permission.py CHANGED
@@ -226,6 +226,12 @@ class PermissionService:
226
226
  self._load_policy_no_lock()
227
227
  return self.enforcer.get_roles_for_user(user_id)
228
228
 
229
+ def get_users_for_role(self, role: str) -> List[str]:
230
+ """Get all users for a role."""
231
+ self._lazy_initialize()
232
+ self._load_policy_no_lock()
233
+ return self.enforcer.get_users_for_role(role)
234
+
229
235
  def check_endpoint_permission(self, user_id: str, path: str,
230
236
  method: str) -> bool:
231
237
  """Check permission."""