skypilot-nightly 1.0.0.dev20250723__py3-none-any.whl → 1.0.0.dev20250725__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.

Files changed (49)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +27 -1
  3. sky/client/cli/command.py +61 -21
  4. sky/client/sdk.pyi +296 -0
  5. sky/clouds/utils/oci_utils.py +16 -40
  6. sky/dashboard/out/404.html +1 -1
  7. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  8. sky/dashboard/out/clusters/[cluster].html +1 -1
  9. sky/dashboard/out/clusters.html +1 -1
  10. sky/dashboard/out/config.html +1 -1
  11. sky/dashboard/out/index.html +1 -1
  12. sky/dashboard/out/infra/[context].html +1 -1
  13. sky/dashboard/out/infra.html +1 -1
  14. sky/dashboard/out/jobs/[job].html +1 -1
  15. sky/dashboard/out/jobs.html +1 -1
  16. sky/dashboard/out/users.html +1 -1
  17. sky/dashboard/out/volumes.html +1 -1
  18. sky/dashboard/out/workspace/new.html +1 -1
  19. sky/dashboard/out/workspaces/[name].html +1 -1
  20. sky/dashboard/out/workspaces.html +1 -1
  21. sky/exceptions.py +8 -0
  22. sky/global_user_state.py +12 -23
  23. sky/jobs/state.py +12 -24
  24. sky/logs/__init__.py +4 -0
  25. sky/logs/agent.py +14 -0
  26. sky/logs/aws.py +276 -0
  27. sky/server/common.py +14 -1
  28. sky/server/requests/payloads.py +20 -4
  29. sky/server/rest.py +6 -0
  30. sky/server/server.py +5 -1
  31. sky/templates/aws-ray.yml.j2 +7 -1
  32. sky/templates/azure-ray.yml.j2 +1 -1
  33. sky/templates/do-ray.yml.j2 +1 -1
  34. sky/templates/lambda-ray.yml.j2 +1 -1
  35. sky/templates/nebius-ray.yml.j2 +1 -1
  36. sky/templates/paperspace-ray.yml.j2 +1 -1
  37. sky/templates/runpod-ray.yml.j2 +1 -1
  38. sky/utils/config_utils.py +6 -4
  39. sky/utils/db/migration_utils.py +60 -19
  40. sky/utils/rich_utils.py +2 -3
  41. sky/utils/schemas.py +67 -22
  42. {skypilot_nightly-1.0.0.dev20250723.dist-info → skypilot_nightly-1.0.0.dev20250725.dist-info}/METADATA +1 -1
  43. {skypilot_nightly-1.0.0.dev20250723.dist-info → skypilot_nightly-1.0.0.dev20250725.dist-info}/RECORD +49 -47
  44. /sky/dashboard/out/_next/static/{mym3Ciwp-zqU7ZpOLGnrW → SiA7c33x_DqO42M373Okd}/_buildManifest.js +0 -0
  45. /sky/dashboard/out/_next/static/{mym3Ciwp-zqU7ZpOLGnrW → SiA7c33x_DqO42M373Okd}/_ssgManifest.js +0 -0
  46. {skypilot_nightly-1.0.0.dev20250723.dist-info → skypilot_nightly-1.0.0.dev20250725.dist-info}/WHEEL +0 -0
  47. {skypilot_nightly-1.0.0.dev20250723.dist-info → skypilot_nightly-1.0.0.dev20250725.dist-info}/entry_points.txt +0 -0
  48. {skypilot_nightly-1.0.0.dev20250723.dist-info → skypilot_nightly-1.0.0.dev20250725.dist-info}/licenses/LICENSE +0 -0
  49. {skypilot_nightly-1.0.0.dev20250723.dist-info → skypilot_nightly-1.0.0.dev20250725.dist-info}/top_level.txt +0 -0
sky/logs/aws.py ADDED
@@ -0,0 +1,276 @@
+ """AWS CloudWatch logging agent."""
+
+ from typing import Any, Dict, Optional
+
+ import pydantic
+
+ from sky.logs.agent import FluentbitAgent
+ from sky.skylet import constants
+ from sky.utils import common_utils
+ from sky.utils import resources_utils
+
+
+ class _CloudwatchLoggingConfig(pydantic.BaseModel):
+     """Configuration for AWS CloudWatch logging agent."""
+     region: Optional[str] = None
+     credentials_file: Optional[str] = None
+     log_group_name: str = 'skypilot-logs'
+     log_stream_prefix: str = 'skypilot-'
+     auto_create_group: bool = True
+     additional_tags: Optional[Dict[str, str]] = None
+
+
+ class _CloudWatchOutputConfig(pydantic.BaseModel):
+     """Auxiliary model for building CloudWatch output config in YAML.
+
+     Ref: https://docs.fluentbit.io/manual/pipeline/outputs/cloudwatch
+     """
+     name: str = 'cloudwatch_logs'
+     match: str = '*'
+     region: Optional[str] = None
+     log_group_name: Optional[str] = None
+     log_stream_prefix: Optional[str] = None
+     auto_create_group: bool = True
+     additional_tags: Optional[Dict[str, str]] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         config = self.model_dump(exclude_none=True)
+         if 'auto_create_group' in config:
+             config['auto_create_group'] = 'true' if config[
+                 'auto_create_group'] else 'false'
+         return config
+
+
+ class CloudwatchLoggingAgent(FluentbitAgent):
+     """AWS CloudWatch logging agent.
+
+     This agent forwards logs from SkyPilot clusters to AWS CloudWatch using
+     Fluent Bit. It supports authentication via IAM roles (preferred), AWS
+     credentials file, or environment variables.
+
+     Example configuration:
+     ```yaml
+     logs:
+       store: aws
+       aws:
+         region: us-west-2
+         log_group_name: skypilot-logs
+         log_stream_prefix: my-cluster-
+         auto_create_group: true
+     ```
+     """
+
+     def __init__(self, config: Dict[str, Any]):
+         """Initialize the CloudWatch logging agent.
+
+         Args:
+             config: The configuration for the CloudWatch logging agent.
+                 See the class docstring for the expected format.
+         """
+         self.config = _CloudwatchLoggingConfig(**config)
+         super().__init__()
+
+     def get_setup_command(self,
+                           cluster_name: resources_utils.ClusterName) -> str:
+         """Get the command to set up the CloudWatch logging agent.
+
+         Args:
+             cluster_name: The name of the cluster.
+
+         Returns:
+             The command to set up the CloudWatch logging agent.
+         """
+
+         if self.config.credentials_file:
+             credential_path = self.config.credentials_file
+
+         # Set AWS credentials and check whether credentials are valid.
+         # CloudWatch plugin supports IAM roles, credentials file, and
+         # environment variables. We prefer IAM roles when available
+         # (on EC2 instances). If credentials file is provided, we use
+         # it. Otherwise, we check if credentials are available in
+         # the environment.
+         pre_cmd = ''
+         if self.config.credentials_file:
+             pre_cmd = (
+                 f'export AWS_SHARED_CREDENTIALS_FILE={credential_path}; '
+                 f'if [ ! -f {credential_path} ]; then '
+                 f'echo "ERROR: AWS credentials file {credential_path} '
+                 f'not found. Please check if the file exists and is '
+                 f'accessible." && exit 1; '
+                 f'fi; '
+                 f'if ! grep -q "\\[.*\\]" {credential_path} || '
+                 f'! grep -q "aws_access_key_id" {credential_path}; then '
+                 f'echo "ERROR: AWS credentials file {credential_path} is '
+                 f'invalid. It should contain a profile section '
+                 f'[profile_name] and aws_access_key_id." && exit 1; '
+                 f'fi;')
+         else:
+             # Check if we're running on EC2 with an IAM role or if
+             # AWS credentials are available in the environment
+             pre_cmd = (
+                 'if ! curl -s -m 1 http://169.254.169.254'
+                 '/latest/meta-data/iam/security-credentials/ > /dev/null; '
+                 'then '
+                 # failed EC2 check, look for env vars
+                 'if [ -z "$AWS_ACCESS_KEY_ID" ] || '
+                 '[ -z "$AWS_SECRET_ACCESS_KEY" ]; then '
+                 'echo "ERROR: AWS CloudWatch logging configuration error. '
+                 'Not running on EC2 with IAM role and AWS credentials not '
+                 'found in environment. Please do one of the following: '
+                 '1. Run on an EC2 instance with an IAM role that has '
+                 'CloudWatch permissions, 2. Set AWS_ACCESS_KEY_ID and '
+                 'AWS_SECRET_ACCESS_KEY environment variables, or '
+                 '3. Provide a credentials file via logs.aws.credentials_file '
+                 'in SkyPilot config." && exit 1; '
+                 'fi; '
+                 'fi;')
+
+         # If region is specified, set it in the environment
+         if self.config.region:
+             pre_cmd += f' export AWS_REGION={self.config.region};'
+         else:
+             # If region is not specified, check if it's available in
+             # the environment or credentials file
+             pre_cmd += (
+                 ' if [ -z "$AWS_REGION" ] && '
+                 '[ -z "$AWS_DEFAULT_REGION" ]; then '
+                 'echo "WARNING: AWS region not specified in configuration or '
+                 'environment. CloudWatch logging may fail if the region '
+                 'cannot be determined. Consider setting logs.aws.region in '
+                 'SkyPilot config."; '
+                 'fi; ')
+
+         # Add a test command to verify AWS credentials work with CloudWatch
+         pre_cmd += (
+             ' echo "Verifying AWS CloudWatch access..."; '
+             'if command -v aws > /dev/null; then '
+             'aws cloudwatch list-metrics --namespace AWS/Logs --max-items 1 '
+             '> /dev/null 2>&1 || '
+             '{ echo "ERROR: Failed to access AWS CloudWatch. Please check '
+             'your credentials and permissions."; '
+             'echo "The IAM role or user must have cloudwatch:ListMetrics '
+             'and logs:* permissions."; '
+             'exit 1; }; '
+             'else echo "AWS CLI not installed, skipping CloudWatch access '
+             'verification."; '
+             'fi; ')
+
+         return pre_cmd + ' ' + super().get_setup_command(cluster_name)
+
+     def fluentbit_config(self,
+                          cluster_name: resources_utils.ClusterName) -> str:
+         """Get the Fluent Bit configuration for CloudWatch.
+
+         This overrides the base method to add a fallback output for local file
+         logging in case CloudWatch logging fails.
+
+         Args:
+             cluster_name: The name of the cluster.
+
+         Returns:
+             The Fluent Bit configuration as a YAML string.
+         """
+         display_name = cluster_name.display_name
+         unique_name = cluster_name.name_on_cloud
+         # Build tags for the log stream
+         tags = {
+             'skypilot.cluster_name': display_name,
+             'skypilot.cluster_id': unique_name,
+         }
+
+         # Add additional tags if provided
+         if self.config.additional_tags:
+             tags.update(self.config.additional_tags)
+
+         log_processors = []
+         for key, value in tags.items():
+             log_processors.append({
+                 'name': 'content_modifier',
+                 'action': 'upsert',
+                 'key': key,
+                 'value': value
+             })
+
+         cfg_dict = {
+             'pipeline': {
+                 'inputs': [{
+                     'name': 'tail',
+                     'path': f'{constants.SKY_LOGS_DIRECTORY}/*/*.log',
+                     'path_key': 'log_path',
+                     # Shorten the refresh interval from 60s to 1s since every
+                     # job creates a new log file and we must be responsive
+                     # for this: the VM might be autodown within a minute
+                     # right after the job completion.
+                     'refresh_interval': 1,
+                     'processors': {
+                         'logs': log_processors,
+                     }
+                 }],
+                 'outputs': [self.fluentbit_output_config(cluster_name)],
+             }
+         }
+
+         # Add fallback outputs for graceful failure handling
+         cfg_dict = self.add_fallback_outputs(cfg_dict)
+
+         return common_utils.dump_yaml_str(cfg_dict)
+
+     def add_fallback_outputs(self, cfg_dict: Dict[str, Any]) -> Dict[str, Any]:
+         """Add fallback outputs to the Fluent Bit configuration.
+
+         This adds a local file output as a fallback in case
+         CloudWatch logging fails.
+
+         Args:
+             cfg_dict: The Fluent Bit configuration dictionary.
+
+         Returns:
+             The updated configuration dictionary.
+         """
+         # Add a local file output as a fallback
+         fallback_output = {
+             'name': 'file',
+             'match': '*',
+             'path': '/tmp/skypilot_logs_fallback.log',
+             'format': 'out_file',
+         }
+
+         # Add the fallback output to the configuration
+         cfg_dict['pipeline']['outputs'].append(fallback_output)
+
+         return cfg_dict
+
+     def fluentbit_output_config(
+             self, cluster_name: resources_utils.ClusterName) -> Dict[str, Any]:
+         """Get the Fluent Bit output configuration for CloudWatch.
+
+         Args:
+             cluster_name: The name of the cluster.
+
+         Returns:
+             The Fluent Bit output configuration for CloudWatch.
+         """
+         unique_name = cluster_name.name_on_cloud
+
+         # Format the log stream name to include cluster information
+         # This helps with identifying logs in CloudWatch
+         log_stream_prefix = f'{self.config.log_stream_prefix}{unique_name}-'
+
+         # Create the CloudWatch output configuration with error handling options
+         return _CloudWatchOutputConfig(
+             region=self.config.region,
+             log_group_name=self.config.log_group_name,
+             log_stream_prefix=log_stream_prefix,
+             auto_create_group=self.config.auto_create_group,
+         ).to_dict()
+
+     def get_credential_file_mounts(self) -> Dict[str, str]:
+         """Get the credential file mounts for the CloudWatch logging agent.
+
+         Returns:
+             A dictionary mapping local credential file paths to remote paths.
+         """
+         if self.config.credentials_file:
+             return {self.config.credentials_file: self.config.credentials_file}
+         return {}
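
As a quick orientation to the new module, here is a hypothetical usage sketch (not part of the wheel). It assumes the resources_utils.ClusterName constructor takes display_name and name_on_cloud, which is suggested by the attribute accesses above but not confirmed by this diff.

# Hypothetical sketch: build the agent from a `logs.aws` config dict and
# inspect the Fluent Bit output section it would generate.
from sky.logs.aws import CloudwatchLoggingAgent
from sky.utils import resources_utils

agent = CloudwatchLoggingAgent({
    'region': 'us-west-2',
    'log_group_name': 'skypilot-logs',
    'log_stream_prefix': 'my-cluster-',
})
# ClusterName arguments are assumed from how the agent reads them above.
cluster = resources_utils.ClusterName(display_name='my-cluster',
                                      name_on_cloud='my-cluster-2ea4')
print(agent.fluentbit_output_config(cluster))
# Expected shape (values follow from the config above):
# {'name': 'cloudwatch_logs', 'match': '*', 'region': 'us-west-2',
#  'log_group_name': 'skypilot-logs',
#  'log_stream_prefix': 'my-cluster-my-cluster-2ea4-',
#  'auto_create_group': 'true'}
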
sky/server/common.py CHANGED
@@ -13,12 +13,14 @@ import shutil
  import subprocess
  import sys
  import tempfile
+ import threading
  import time
  import typing
  from typing import Any, Dict, Literal, Optional, Tuple, Union
  from urllib import parse
  import uuid

+ import cachetools
  import colorama
  import filelock

@@ -132,6 +134,8 @@ def get_api_cookie_jar() -> requests.cookies.RequestsCookieJar:
  def set_api_cookie_jar(cookie_jar: CookieJar,
                         create_if_not_exists: bool = True) -> None:
      """Updates the file cookie jar with the given cookie jar."""
+     if len(cookie_jar) == 0:
+         return
      cookie_path = get_api_cookie_jar_path()
      if not cookie_path.exists() and not create_if_not_exists:
          # if the file doesn't exist and we don't want to create it, do nothing
@@ -274,6 +278,10 @@ def _handle_non_200_server_status(
      return ApiServerInfo(status=ApiServerStatus.UNHEALTHY)


+ @cachetools.cached(cache=cachetools.TTLCache(maxsize=10,
+                                              ttl=5.0,
+                                              timer=time.time),
+                    lock=threading.RLock())
  def get_api_server_status(endpoint: Optional[str] = None) -> ApiServerInfo:
      """Retrieve the status of the API server.

@@ -351,7 +359,9 @@ def get_api_server_status(endpoint: Optional[str] = None) -> ApiServerInfo:
                                          error=version_info.error)

          cookies = get_cookies_from_response(response)
-         set_api_cookie_jar(cookies, create_if_not_exists=False)
+         # Save or refresh the cookie jar in case of session affinity and
+         # OAuth.
+         set_api_cookie_jar(cookies, create_if_not_exists=True)
          return server_info
      except (json.JSONDecodeError, AttributeError) as e:
          # Try to check if we got redirected to a login page.
@@ -409,6 +419,7 @@ def _start_api_server(deploy: bool = False,
      server_url = get_server_url(host)
      assert server_url in AVAILABLE_LOCAL_API_SERVER_URLS, (
          f'server url {server_url} is not a local url')
+
      with rich_utils.client_status('Starting SkyPilot API server, '
                                    f'view logs at {constants.API_SERVER_LOGS}'):
          logger.info(f'{colorama.Style.DIM}Failed to connect to '
@@ -484,6 +495,8 @@ def _start_api_server(deploy: bool = False,
                  'SkyPilot API server process exited unexpectedly.\n'
                  f'View logs at: {constants.API_SERVER_LOGS}')
          try:
+             # Clear the cache to ensure fresh checks during startup
+             get_api_server_status.cache_clear()  # type: ignore
              check_server_healthy()
          except exceptions.APIVersionMismatchError:
              raise
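
The decorator above memoizes get_api_server_status for 5 seconds, and the startup path clears that cache before re-checking health. A minimal standalone sketch of the same cachetools pattern, assuming a hypothetical probe_server function and a recent cachetools release (cache_clear() on a cached-decorated function is a newer cachetools feature):

import threading
import time

import cachetools


@cachetools.cached(cache=cachetools.TTLCache(maxsize=10, ttl=5.0,
                                             timer=time.time),
                   lock=threading.RLock())
def probe_server(endpoint: str = 'default') -> float:
    # Stand-in for an expensive health check against the API server.
    return time.time()


first = probe_server()
assert probe_server() == first  # Served from the cache within the 5s TTL.
probe_server.cache_clear()      # Analogous to the cache_clear() call above.
assert probe_server() >= first  # Forces a fresh probe.
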
sky/server/requests/payloads.py CHANGED
@@ -203,17 +203,33 @@ class DagRequestBody(RequestBody):
          return kwargs


- class ValidateBody(DagRequestBody):
+ class DagRequestBodyWithRequestOptions(DagRequestBody):
+     """Request body base class for endpoints with a dag and request options."""
+     request_options: Optional[admin_policy.RequestOptions]
+
+     def get_request_options(self) -> Optional[admin_policy.RequestOptions]:
+         """Get the request options."""
+         if self.request_options is None:
+             return None
+         if isinstance(self.request_options, dict):
+             return admin_policy.RequestOptions(**self.request_options)
+         return self.request_options
+
+     def to_kwargs(self) -> Dict[str, Any]:
+         kwargs = super().to_kwargs()
+         kwargs['request_options'] = self.get_request_options()
+         return kwargs
+
+
+ class ValidateBody(DagRequestBodyWithRequestOptions):
      """The request body for the validate endpoint."""
      dag: str
-     request_options: Optional[admin_policy.RequestOptions]


- class OptimizeBody(DagRequestBody):
+ class OptimizeBody(DagRequestBodyWithRequestOptions):
      """The request body for the optimize endpoint."""
      dag: str
      minimize: common_lib.OptimizeTarget = common_lib.OptimizeTarget.COST
-     request_options: Optional[admin_policy.RequestOptions]


  class LaunchBody(RequestBody):
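
The point of get_request_options() is that pydantic can hand the server a plain dict for request_options; the helper normalizes it before it reaches the admin policy. A minimal sketch of that coercion on its own (the RequestOptions field names are assumptions based on sky/admin_policy.py, not confirmed by this diff):

from typing import Optional, Union

from sky import admin_policy


def coerce(
    request_options: Optional[Union[dict, admin_policy.RequestOptions]]
) -> Optional[admin_policy.RequestOptions]:
    # Mirrors DagRequestBodyWithRequestOptions.get_request_options().
    if request_options is None:
        return None
    if isinstance(request_options, dict):
        return admin_policy.RequestOptions(**request_options)
    return request_options


opts = coerce({
    'cluster_name': 'my-cluster',
    'idle_minutes_to_autostop': None,
    'down': False,
    'dryrun': False,
})
assert isinstance(opts, admin_policy.RequestOptions)
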
sky/server/rest.py CHANGED
@@ -89,6 +89,12 @@ def retry_transient_errors(max_retries: int = 3,
      for retry_cnt in range(max_retries):
          try:
              return func(*args, **kwargs)
+         # Occurs when the server proactively interrupts the request
+         # during rolling update, we can retry immediately on the
+         # new replica.
+         except exceptions.RequestInterruptedError:
+             logger.debug('Request interrupted. Retry immediately.')
+             continue
          except Exception as e:  # pylint: disable=broad-except
              if retry_cnt >= max_retries - 1:
                  # Retries exhausted.
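
In other words, a request cut off by a rolling update is retried right away instead of being counted against the normal backoff path. A hypothetical sketch of that behavior, assuming retry_transient_errors is usable as a decorator factory and that the new RequestInterruptedError takes no arguments (neither is confirmed by this diff):

from sky import exceptions
from sky.server import rest

_calls = {'n': 0}


@rest.retry_transient_errors(max_retries=3)
def fetch() -> str:
    _calls['n'] += 1
    if _calls['n'] == 1:
        # Simulate the server interrupting the request mid rolling update.
        raise exceptions.RequestInterruptedError()
    return 'ok'


assert fetch() == 'ok'  # Retried immediately on the simulated new replica.
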
sky/server/server.py CHANGED
@@ -827,7 +827,8 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
      # added RTTs. For now, we stick to doing the validation inline in the
      # server thread.
      with admin_policy_utils.apply_and_use_config_in_current_request(
-             dag, request_options=validate_body.request_options) as dag:
+             dag,
+             request_options=validate_body.get_request_options()) as dag:
          # Skip validating workdir and file_mounts, as those need to be
          # validated after the files are uploaded to the SkyPilot API server
          # with `upload_mounts_to_api_server`.
@@ -1763,6 +1764,9 @@ if __name__ == '__main__':

      from sky.server import uvicorn as skyuvicorn

+     # Initialize global user state db
+     global_user_state.initialize_and_get_db()
+     # Initialize request db
      requests_lib.reset_db_and_logs()

      parser = argparse.ArgumentParser()
sky/templates/aws-ray.yml.j2 CHANGED
@@ -19,7 +19,7 @@ docker:
      username: |-
        {{docker_login_config.username}}
      password: |-
-       {{docker_login_config.password}}
+       {{docker_login_config.password | indent(6) }}
      server: |-
        {{docker_login_config.server}}
    {%- endif %}
@@ -131,6 +131,12 @@ available_node_types:
    - systemctl disable apt-daily.timer apt-daily-upgrade.timer unattended-upgrades.service
    - systemctl mask apt-daily.service apt-daily-upgrade.service unattended-upgrades.service
    - systemctl daemon-reload
+   {%- if runcmd %}
+   runcmd:
+   {%- for cmd in runcmd %}
+   - {{cmd}}
+   {%- endfor %}
+   {%- endif %}
    TagSpecifications:
    - ResourceType: instance
      Tags:
sky/templates/azure-ray.yml.j2 CHANGED
@@ -19,7 +19,7 @@ docker:
      username: |-
        {{docker_login_config.username}}
      password: |-
-       {{docker_login_config.password}}
+       {{docker_login_config.password | indent(6) }}
      server: |-
        {{docker_login_config.server}}
    {%- endif %}
sky/templates/do-ray.yml.j2 CHANGED
@@ -19,7 +19,7 @@ docker:
      username: |-
        {{docker_login_config.username}}
      password: |-
-       {{docker_login_config.password}}
+       {{docker_login_config.password | indent(6) }}
      server: |-
        {{docker_login_config.server}}
    {%- endif %}
sky/templates/lambda-ray.yml.j2 CHANGED
@@ -19,7 +19,7 @@ docker:
      username: |-
        {{docker_login_config.username}}
      password: |-
-       {{docker_login_config.password}}
+       {{docker_login_config.password | indent(6) }}
      server: |-
        {{docker_login_config.server}}
    {%- endif %}
sky/templates/nebius-ray.yml.j2 CHANGED
@@ -25,7 +25,7 @@ docker:
      username: |-
        {{docker_login_config.username}}
      password: |-
-       {{docker_login_config.password}}
+       {{docker_login_config.password | indent(6) }}
      server: |-
        {{docker_login_config.server}}
    {%- endif %}
sky/templates/paperspace-ray.yml.j2 CHANGED
@@ -19,7 +19,7 @@ docker:
      username: |-
        {{docker_login_config.username}}
      password: |-
-       {{docker_login_config.password}}
+       {{docker_login_config.password | indent(6) }}
      server: |-
        {{docker_login_config.server}}
    {%- endif %}
sky/templates/runpod-ray.yml.j2 CHANGED
@@ -20,7 +20,7 @@ provider:
      username: |-
        {{docker_login_config.username}}
      password: |-
-       {{docker_login_config.password}}
+       {{docker_login_config.password | indent(6) }}
      server: |-
        {{docker_login_config.server}}
    {%- endif %}
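
The indent(6) filter repeated across these templates matters when the registry password is multi-line (for example, a JSON key): every continuation line must stay indented to remain inside the YAML block scalar. A standalone sketch with a made-up template snippet and password:

import jinja2
import yaml

template = jinja2.Template('docker:\n'
                           '  docker_login_config:\n'
                           '    password: |-\n'
                           '      {{ password | indent(6) }}\n')
rendered = template.render(password='line-one\nline-two')
# indent(6) indents every line after the first, so 'line-two' stays inside
# the |- block scalar instead of breaking the YAML at column 0.
assert yaml.safe_load(rendered)['docker']['docker_login_config'][
    'password'] == 'line-one\nline-two'
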
sky/utils/config_utils.py CHANGED
@@ -6,6 +6,8 @@ from sky import sky_logging

  logger = sky_logging.init_logger(__name__)

+ _REGION_CONFIG_CLOUDS = ['nebius', 'oci']
+

  class Config(Dict[str, Any]):
      """SkyPilot config that supports setting/getting values with nested keys."""
@@ -248,7 +250,7 @@ def get_cloud_config_value_from_dict(
      region_key = None
      if cloud == 'kubernetes':
          region_key = 'context_configs'
-     if cloud == 'nebius':
+     elif cloud in _REGION_CONFIG_CLOUDS:
          region_key = 'region_configs'

      per_context_config = None
@@ -257,7 +259,7 @@ def get_cloud_config_value_from_dict(
          keys=(cloud, region_key, region) + keys,
          default_value=None,
          override_configs=override_configs)
-     if not per_context_config and cloud == 'nebius':
+     if not per_context_config and cloud in _REGION_CONFIG_CLOUDS:
          # TODO (kyuds): Backward compatibility, remove after 0.11.0.
          per_context_config = input_config.get_nested(
              keys=(cloud, region) + keys,
@@ -265,9 +267,9 @@ def get_cloud_config_value_from_dict(
              override_configs=override_configs)
      if per_context_config is not None:
          logger.info(
-             'Nebius configuration is using the legacy format. \n'
+             f'{cloud} configuration is using the legacy format. \n'
              'This format will be deprecated after 0.11.0, refer to '
-             '`https://docs.skypilot.co/en/latest/reference/config.html#nebius` '  # pylint: disable=line-too-long
+             '`https://docs.skypilot.co/en/latest/reference/config.html` '  # pylint: disable=line-too-long
              'for the new format. Please use `region_configs` to specify region specific configuration.'
          )
      # if no override found for specified region
sky/utils/db/migration_utils.py CHANGED
@@ -3,6 +3,7 @@
  import contextlib
  import logging
  import os
+ import pathlib

  from alembic import command as alembic_command
  from alembic.config import Config
@@ -10,6 +11,12 @@ from alembic.runtime import migration
  import filelock
  import sqlalchemy

+ from sky import sky_logging
+ from sky import skypilot_config
+ from sky.skylet import constants
+
+ logger = sky_logging.init_logger(__name__)
+
  DB_INIT_LOCK_TIMEOUT_SECONDS = 10

  GLOBAL_USER_STATE_DB_NAME = 'state_db'
@@ -21,6 +28,21 @@ SPOT_JOBS_VERSION = '001'
  SPOT_JOBS_LOCK_PATH = '~/.sky/locks/.spot_jobs_db.lock'


+ def get_engine(db_name: str):
+     conn_string = None
+     if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
+         conn_string = skypilot_config.get_nested(('db',), None)
+     if conn_string:
+         logger.debug(f'using db URI from {conn_string}')
+         engine = sqlalchemy.create_engine(conn_string,
+                                           poolclass=sqlalchemy.NullPool)
+     else:
+         db_path = os.path.expanduser(f'~/.sky/{db_name}.db')
+         pathlib.Path(db_path).parents[0].mkdir(parents=True, exist_ok=True)
+         engine = sqlalchemy.create_engine('sqlite:///' + db_path)
+     return engine
+
+
  @contextlib.contextmanager
  def db_lock(db_name: str):
      lock_path = os.path.expanduser(f'~/.sky/locks/.{db_name}.lock')
@@ -37,7 +59,6 @@ def db_lock(db_name: str):

  def get_alembic_config(engine: sqlalchemy.engine.Engine, section: str):
      """Get Alembic configuration for the given section"""
-     # Use the alembic.ini file from setup_files (included in wheel)
      # From sky/utils/db/migration_utils.py -> sky/setup_files/alembic.ini
      alembic_ini_path = os.path.join(
          os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
@@ -47,31 +68,29 @@ def get_alembic_config(engine: sqlalchemy.engine.Engine, section: str):
      # Override the database URL to match SkyPilot's current connection
      # Use render_as_string to get the full URL with password
      url = engine.url.render_as_string(hide_password=False)
+     # Replace % with %% to escape the % character in the URL.
+     # set_section_option uses variable interpolation, which treats % as a
+     # special character, so any '%' symbol not used for interpolation needs
+     # to be escaped.
+     url = url.replace('%', '%%')
      alembic_cfg.set_section_option(section, 'sqlalchemy.url', url)

      return alembic_cfg


- def safe_alembic_upgrade(engine: sqlalchemy.engine.Engine,
-                          alembic_config: Config, target_revision: str):
-     """Only upgrade if current version is older than target.
-
-     This handles the case where a database was created with a newer version of
-     the code and we're now running older code. Since our migrations are purely
-     additive, it's safe to run a newer database with older code.
+ def needs_upgrade(engine: sqlalchemy.engine.Engine, section: str,
+                   target_revision: str):
+     """Check if the database needs to be upgraded.

      Args:
          engine: SQLAlchemy engine for the database
-         alembic_config: Alembic configuration object
+         section: Alembic section to upgrade (e.g., 'state_db' or 'spot_jobs_db')
          target_revision: Target revision to upgrade to (e.g., '001')
      """
-     # set alembic logger to warning level
-     alembic_logger = logging.getLogger('alembic')
-     alembic_logger.setLevel(logging.WARNING)
-
      current_rev = None

-     # Get the current revision from the database
+     # get alembic config for the given section
+     alembic_config = get_alembic_config(engine, section)
      version_table = alembic_config.get_section_option(
          alembic_config.config_ini_section, 'version_table', 'alembic_version')
@@ -81,13 +100,35 @@ def safe_alembic_upgrade(engine: sqlalchemy.engine.Engine,
          current_rev = context.get_current_revision()

      if current_rev is None:
-         alembic_command.upgrade(alembic_config, target_revision)
-         return
+         return True

      # Compare revisions - assuming they are numeric strings like '001', '002'
      current_rev_num = int(current_rev)
      target_rev_num = int(target_revision)

-     # only upgrade if current revision is older than target revision
-     if current_rev_num < target_rev_num:
-         alembic_command.upgrade(alembic_config, target_revision)
+     return current_rev_num < target_rev_num
+
+
+ def safe_alembic_upgrade(engine: sqlalchemy.engine.Engine, section: str,
+                          target_revision: str):
+     """Upgrade the database if needed. Uses a file lock to ensure
+     that only one process tries to upgrade the database at a time.
+
+     Args:
+         engine: SQLAlchemy engine for the database
+         section: Alembic section to upgrade (e.g., 'state_db' or 'spot_jobs_db')
+         target_revision: Target revision to upgrade to (e.g., '001')
+     """
+     # set alembic logger to warning level
+     alembic_logger = logging.getLogger('alembic')
+     alembic_logger.setLevel(logging.WARNING)
+
+     alembic_config = get_alembic_config(engine, section)
+
+     # only acquire lock if db needs upgrade
+     if needs_upgrade(engine, section, target_revision):
+         with db_lock(section):
+             # check again if db needs upgrade in case another
+             # process upgraded it while we were waiting for the lock
+             if needs_upgrade(engine, section, target_revision):
+                 alembic_command.upgrade(alembic_config, target_revision)
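
With this split, callers hand the engine and section name to safe_alembic_upgrade, which takes the file lock only when the schema is actually behind and re-checks under the lock before upgrading. A hypothetical call site (the '001' revision literal and the exact usage in global_user_state.py are assumptions, not shown in this diff):

from sky.utils.db import migration_utils

engine = migration_utils.get_engine(migration_utils.GLOBAL_USER_STATE_DB_NAME)
# The section name doubles as the alembic config section and the lock name.
migration_utils.safe_alembic_upgrade(engine,
                                     migration_utils.GLOBAL_USER_STATE_DB_NAME,
                                     '001')
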