skypilot-nightly 1.0.0.dev20250114__py3-none-any.whl → 1.0.0.dev20250124__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/backends/cloud_vm_ray_backend.py +50 -67
- sky/check.py +31 -1
- sky/cli.py +11 -34
- sky/clouds/kubernetes.py +3 -3
- sky/clouds/service_catalog/kubernetes_catalog.py +14 -0
- sky/core.py +8 -5
- sky/data/storage.py +66 -14
- sky/global_user_state.py +1 -1
- sky/jobs/constants.py +8 -7
- sky/jobs/controller.py +19 -22
- sky/jobs/core.py +0 -2
- sky/jobs/recovery_strategy.py +114 -143
- sky/jobs/scheduler.py +283 -0
- sky/jobs/state.py +263 -21
- sky/jobs/utils.py +338 -96
- sky/provision/aws/config.py +48 -26
- sky/provision/gcp/instance_utils.py +15 -9
- sky/provision/kubernetes/instance.py +1 -1
- sky/provision/kubernetes/utils.py +76 -18
- sky/resources.py +1 -1
- sky/serve/autoscalers.py +359 -301
- sky/serve/controller.py +10 -8
- sky/serve/core.py +84 -7
- sky/serve/load_balancer.py +27 -10
- sky/serve/replica_managers.py +1 -3
- sky/serve/serve_state.py +10 -5
- sky/serve/serve_utils.py +28 -1
- sky/serve/service.py +4 -3
- sky/serve/service_spec.py +31 -0
- sky/skylet/constants.py +1 -1
- sky/skylet/events.py +7 -3
- sky/skylet/job_lib.py +10 -30
- sky/skylet/log_lib.py +8 -8
- sky/skylet/log_lib.pyi +3 -0
- sky/skylet/skylet.py +1 -1
- sky/templates/jobs-controller.yaml.j2 +7 -3
- sky/templates/sky-serve-controller.yaml.j2 +4 -0
- sky/utils/db_utils.py +18 -4
- sky/utils/kubernetes/deploy_remote_cluster.sh +5 -5
- sky/utils/resources_utils.py +25 -21
- sky/utils/schemas.py +13 -0
- sky/utils/subprocess_utils.py +48 -9
- {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/METADATA +4 -1
- {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/RECORD +49 -48
- {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/top_level.txt +0 -0
sky/serve/controller.py
CHANGED
@@ -67,9 +67,16 @@ class SkyServeController:
|
|
67
67
|
try:
|
68
68
|
replica_infos = serve_state.get_replica_infos(
|
69
69
|
self._service_name)
|
70
|
+
# Use the active versions set by replica manager to make
|
71
|
+
# sure we only scale down the outdated replicas that are
|
72
|
+
# not used by the load balancer.
|
73
|
+
record = serve_state.get_service_from_name(self._service_name)
|
74
|
+
assert record is not None, ('No service record found for '
|
75
|
+
f'{self._service_name}')
|
76
|
+
active_versions = record['active_versions']
|
70
77
|
logger.info(f'All replica info: {replica_infos}')
|
71
|
-
scaling_options = self._autoscaler.
|
72
|
-
replica_infos)
|
78
|
+
scaling_options = self._autoscaler.generate_scaling_decisions(
|
79
|
+
replica_infos, active_versions)
|
73
80
|
for scaling_option in scaling_options:
|
74
81
|
logger.info(f'Scaling option received: {scaling_option}')
|
75
82
|
if (scaling_option.operator ==
|
@@ -77,15 +84,10 @@ class SkyServeController:
|
|
77
84
|
assert (scaling_option.target is None or isinstance(
|
78
85
|
scaling_option.target, dict)), scaling_option
|
79
86
|
self._replica_manager.scale_up(scaling_option.target)
|
80
|
-
|
81
|
-
autoscalers.AutoscalerDecisionOperator.SCALE_DOWN):
|
87
|
+
else:
|
82
88
|
assert isinstance(scaling_option.target,
|
83
89
|
int), scaling_option
|
84
90
|
self._replica_manager.scale_down(scaling_option.target)
|
85
|
-
else:
|
86
|
-
with ux_utils.enable_traceback():
|
87
|
-
logger.error('Error in scaling_option.operator: '
|
88
|
-
f'{scaling_option.operator}')
|
89
91
|
except Exception as e: # pylint: disable=broad-except
|
90
92
|
# No matter what error happens, we should keep the
|
91
93
|
# monitor running.
|
sky/serve/core.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
1
|
"""SkyServe core APIs."""
|
2
2
|
import re
|
3
|
+
import signal
|
4
|
+
import subprocess
|
3
5
|
import tempfile
|
6
|
+
import threading
|
4
7
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
5
8
|
|
6
9
|
import colorama
|
@@ -18,6 +21,7 @@ from sky.serve import serve_utils
|
|
18
21
|
from sky.skylet import constants
|
19
22
|
from sky.usage import usage_lib
|
20
23
|
from sky.utils import admin_policy_utils
|
24
|
+
from sky.utils import command_runner
|
21
25
|
from sky.utils import common_utils
|
22
26
|
from sky.utils import controller_utils
|
23
27
|
from sky.utils import resources_utils
|
@@ -91,6 +95,38 @@ def _validate_service_task(task: 'sky.Task') -> None:
|
|
91
95
|
'Please specify the same port instead.')
|
92
96
|
|
93
97
|
|
98
|
+
def _rewrite_tls_credential_paths_and_get_tls_env_vars(
|
99
|
+
service_name: str, task: 'sky.Task') -> Dict[str, Any]:
|
100
|
+
"""Rewrite the paths of TLS credentials in the task.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
service_name: Name of the service.
|
104
|
+
task: sky.Task to rewrite.
|
105
|
+
|
106
|
+
Returns:
|
107
|
+
The generated template variables for TLS.
|
108
|
+
"""
|
109
|
+
service_spec = task.service
|
110
|
+
# Already checked by _validate_service_task
|
111
|
+
assert service_spec is not None
|
112
|
+
if service_spec.tls_credential is None:
|
113
|
+
return {'use_tls': False}
|
114
|
+
remote_tls_keyfile = (
|
115
|
+
serve_utils.generate_remote_tls_keyfile_name(service_name))
|
116
|
+
remote_tls_certfile = (
|
117
|
+
serve_utils.generate_remote_tls_certfile_name(service_name))
|
118
|
+
tls_template_vars = {
|
119
|
+
'use_tls': True,
|
120
|
+
'remote_tls_keyfile': remote_tls_keyfile,
|
121
|
+
'remote_tls_certfile': remote_tls_certfile,
|
122
|
+
'local_tls_keyfile': service_spec.tls_credential.keyfile,
|
123
|
+
'local_tls_certfile': service_spec.tls_credential.certfile,
|
124
|
+
}
|
125
|
+
service_spec.tls_credential = serve_utils.TLSCredential(
|
126
|
+
remote_tls_keyfile, remote_tls_certfile)
|
127
|
+
return tls_template_vars
|
128
|
+
|
129
|
+
|
94
130
|
@usage_lib.entrypoint
|
95
131
|
def up(
|
96
132
|
task: 'sky.Task',
|
@@ -136,6 +172,9 @@ def up(
|
|
136
172
|
controller_utils.maybe_translate_local_file_mounts_and_sync_up(
|
137
173
|
task, path='serve')
|
138
174
|
|
175
|
+
tls_template_vars = _rewrite_tls_credential_paths_and_get_tls_env_vars(
|
176
|
+
service_name, task)
|
177
|
+
|
139
178
|
with tempfile.NamedTemporaryFile(
|
140
179
|
prefix=f'service-task-{service_name}-',
|
141
180
|
mode='w',
|
@@ -164,6 +203,7 @@ def up(
|
|
164
203
|
'remote_user_config_path': remote_config_yaml_path,
|
165
204
|
'modified_catalogs':
|
166
205
|
service_catalog_common.get_modified_catalog_file_mounts(),
|
206
|
+
**tls_template_vars,
|
167
207
|
**controller_utils.shared_controller_vars_to_fill(
|
168
208
|
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
|
169
209
|
remote_user_config_path=remote_config_yaml_path,
|
@@ -269,10 +309,16 @@ def up(
|
|
269
309
|
else:
|
270
310
|
lb_port = serve_utils.load_service_initialization_result(
|
271
311
|
lb_port_payload)
|
272
|
-
|
312
|
+
socket_endpoint = backend_utils.get_endpoints(
|
273
313
|
controller_handle.cluster_name, lb_port,
|
274
314
|
skip_status_check=True).get(lb_port)
|
275
|
-
assert
|
315
|
+
assert socket_endpoint is not None, (
|
316
|
+
'Did not get endpoint for controller.')
|
317
|
+
# Already checked by _validate_service_task
|
318
|
+
assert task.service is not None
|
319
|
+
protocol = ('http'
|
320
|
+
if task.service.tls_credential is None else 'https')
|
321
|
+
endpoint = f'{protocol}://{socket_endpoint}'
|
276
322
|
|
277
323
|
sky_logging.print(
|
278
324
|
f'{fore.CYAN}Service name: '
|
@@ -321,6 +367,7 @@ def update(
|
|
321
367
|
service_name: Name of the service.
|
322
368
|
"""
|
323
369
|
_validate_service_task(task)
|
370
|
+
|
324
371
|
# Always apply the policy again here, even though it might have been applied
|
325
372
|
# in the CLI. This is to ensure that we apply the policy to the final DAG
|
326
373
|
# and get the mutated config.
|
@@ -329,6 +376,14 @@ def update(
|
|
329
376
|
dag, _ = admin_policy_utils.apply(
|
330
377
|
task, use_mutated_config_in_current_request=False)
|
331
378
|
task = dag.tasks[0]
|
379
|
+
|
380
|
+
assert task.service is not None
|
381
|
+
if task.service.tls_credential is not None:
|
382
|
+
logger.warning('Updating TLS keyfile and certfile is not supported. '
|
383
|
+
'Any updates to the keyfile and certfile will not take '
|
384
|
+
'effect. To update TLS keyfile and certfile, please '
|
385
|
+
'tear down the service and spin up a new one.')
|
386
|
+
|
332
387
|
handle = backend_utils.is_controller_accessible(
|
333
388
|
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
|
334
389
|
stopped_message=
|
@@ -596,6 +651,7 @@ def status(
|
|
596
651
|
'requested_resources_str': (str) str representation of
|
597
652
|
requested resources,
|
598
653
|
'load_balancing_policy': (str) load balancing policy name,
|
654
|
+
'tls_encrypted': (bool) whether the service is TLS encrypted,
|
599
655
|
'replica_info': (List[Dict[str, Any]]) replica information,
|
600
656
|
}
|
601
657
|
|
@@ -731,8 +787,29 @@ def tail_logs(
|
|
731
787
|
|
732
788
|
backend = backend_utils.get_backend_from_handle(handle)
|
733
789
|
assert isinstance(backend, backends.CloudVmRayBackend), backend
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
790
|
+
|
791
|
+
if target != serve_utils.ServiceComponent.REPLICA:
|
792
|
+
code = serve_utils.ServeCodeGen.stream_serve_process_logs(
|
793
|
+
service_name,
|
794
|
+
stream_controller=(
|
795
|
+
target == serve_utils.ServiceComponent.CONTROLLER),
|
796
|
+
follow=follow)
|
797
|
+
else:
|
798
|
+
assert replica_id is not None, service_name
|
799
|
+
code = serve_utils.ServeCodeGen.stream_replica_logs(
|
800
|
+
service_name, replica_id, follow)
|
801
|
+
|
802
|
+
# With the stdin=subprocess.DEVNULL, the ctrl-c will not directly
|
803
|
+
# kill the process, so we need to handle it manually here.
|
804
|
+
if threading.current_thread() is threading.main_thread():
|
805
|
+
signal.signal(signal.SIGINT, backend_utils.interrupt_handler)
|
806
|
+
signal.signal(signal.SIGTSTP, backend_utils.stop_handler)
|
807
|
+
|
808
|
+
# Refer to the notes in
|
809
|
+
# sky/backends/cloud_vm_ray_backend.py::CloudVmRayBackend::tail_logs.
|
810
|
+
backend.run_on_head(handle,
|
811
|
+
code,
|
812
|
+
stream_logs=True,
|
813
|
+
process_stream=False,
|
814
|
+
ssh_mode=command_runner.SshMode.INTERACTIVE,
|
815
|
+
stdin=subprocess.DEVNULL)
|
sky/serve/load_balancer.py
CHANGED
@@ -27,10 +27,12 @@ class SkyServeLoadBalancer:
|
|
27
27
|
policy.
|
28
28
|
"""
|
29
29
|
|
30
|
-
def __init__(
|
31
|
-
|
32
|
-
|
33
|
-
|
30
|
+
def __init__(
|
31
|
+
self,
|
32
|
+
controller_url: str,
|
33
|
+
load_balancer_port: int,
|
34
|
+
load_balancing_policy_name: Optional[str] = None,
|
35
|
+
tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
|
34
36
|
"""Initialize the load balancer.
|
35
37
|
|
36
38
|
Args:
|
@@ -38,6 +40,8 @@ class SkyServeLoadBalancer:
|
|
38
40
|
load_balancer_port: The port where the load balancer listens to.
|
39
41
|
load_balancing_policy_name: The name of the load balancing policy
|
40
42
|
to use. Defaults to None.
|
43
|
+
tls_credentials: The TLS credentials for HTTPS endpoint. Defaults
|
44
|
+
to None.
|
41
45
|
"""
|
42
46
|
self._app = fastapi.FastAPI()
|
43
47
|
self._controller_url: str = controller_url
|
@@ -49,6 +53,8 @@ class SkyServeLoadBalancer:
|
|
49
53
|
f'{load_balancing_policy_name}.')
|
50
54
|
self._request_aggregator: serve_utils.RequestsAggregator = (
|
51
55
|
serve_utils.RequestTimestamp())
|
56
|
+
self._tls_credential: Optional[serve_utils.TLSCredential] = (
|
57
|
+
tls_credential)
|
52
58
|
# TODO(tian): httpx.Client has a resource limit of 100 max connections
|
53
59
|
# for each client. We should wait for feedback on the best max
|
54
60
|
# connections.
|
@@ -231,15 +237,25 @@ class SkyServeLoadBalancer:
|
|
231
237
|
# Register controller synchronization task
|
232
238
|
asyncio.create_task(self._sync_with_controller())
|
233
239
|
|
240
|
+
uvicorn_tls_kwargs = ({} if self._tls_credential is None else
|
241
|
+
self._tls_credential.dump_uvicorn_kwargs())
|
242
|
+
|
243
|
+
protocol = 'https' if self._tls_credential is not None else 'http'
|
244
|
+
|
234
245
|
logger.info('SkyServe Load Balancer started on '
|
235
|
-
f'
|
246
|
+
f'{protocol}://0.0.0.0:{self._load_balancer_port}')
|
236
247
|
|
237
|
-
uvicorn.run(self._app,
|
248
|
+
uvicorn.run(self._app,
|
249
|
+
host='0.0.0.0',
|
250
|
+
port=self._load_balancer_port,
|
251
|
+
**uvicorn_tls_kwargs)
|
238
252
|
|
239
253
|
|
240
|
-
def run_load_balancer(
|
241
|
-
|
242
|
-
|
254
|
+
def run_load_balancer(
|
255
|
+
controller_addr: str,
|
256
|
+
load_balancer_port: int,
|
257
|
+
load_balancing_policy_name: Optional[str] = None,
|
258
|
+
tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
|
243
259
|
""" Run the load balancer.
|
244
260
|
|
245
261
|
Args:
|
@@ -251,7 +267,8 @@ def run_load_balancer(controller_addr: str,
|
|
251
267
|
load_balancer = SkyServeLoadBalancer(
|
252
268
|
controller_url=controller_addr,
|
253
269
|
load_balancer_port=load_balancer_port,
|
254
|
-
load_balancing_policy_name=load_balancing_policy_name
|
270
|
+
load_balancing_policy_name=load_balancing_policy_name,
|
271
|
+
tls_credential=tls_credential)
|
255
272
|
load_balancer.run()
|
256
273
|
|
257
274
|
|
sky/serve/replica_managers.py
CHANGED
@@ -998,9 +998,7 @@ class SkyPilotReplicaManager(ReplicaManager):
|
|
998
998
|
# Re-raise the exception if it is not preempted.
|
999
999
|
raise
|
1000
1000
|
job_status = list(job_statuses.values())[0]
|
1001
|
-
if job_status in
|
1002
|
-
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
|
1003
|
-
]:
|
1001
|
+
if job_status in job_lib.JobStatus.user_code_failure_states():
|
1004
1002
|
info.status_property.user_app_failed = True
|
1005
1003
|
serve_state.add_or_update_replica(self._service_name,
|
1006
1004
|
info.replica_id, info)
|
sky/serve/serve_state.py
CHANGED
@@ -79,6 +79,9 @@ db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
|
|
79
79
|
f'TEXT DEFAULT {json.dumps([])!r}')
|
80
80
|
db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
|
81
81
|
'load_balancing_policy', 'TEXT DEFAULT NULL')
|
82
|
+
# Whether the service's load balancer is encrypted with TLS.
|
83
|
+
db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'tls_encrypted',
|
84
|
+
'INTEGER DEFAULT 0')
|
82
85
|
_UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'
|
83
86
|
|
84
87
|
|
@@ -245,7 +248,7 @@ _SERVICE_STATUS_TO_COLOR = {
|
|
245
248
|
|
246
249
|
def add_service(name: str, controller_job_id: int, policy: str,
|
247
250
|
requested_resources_str: str, load_balancing_policy: str,
|
248
|
-
status: ServiceStatus) -> bool:
|
251
|
+
status: ServiceStatus, tls_encrypted: bool) -> bool:
|
249
252
|
"""Add a service in the database.
|
250
253
|
|
251
254
|
Returns:
|
@@ -258,10 +261,11 @@ def add_service(name: str, controller_job_id: int, policy: str,
|
|
258
261
|
"""\
|
259
262
|
INSERT INTO services
|
260
263
|
(name, controller_job_id, status, policy,
|
261
|
-
requested_resources_str, load_balancing_policy)
|
262
|
-
VALUES (?, ?, ?, ?, ?, ?)""",
|
264
|
+
requested_resources_str, load_balancing_policy, tls_encrypted)
|
265
|
+
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
263
266
|
(name, controller_job_id, status.value, policy,
|
264
|
-
requested_resources_str, load_balancing_policy
|
267
|
+
requested_resources_str, load_balancing_policy,
|
268
|
+
int(tls_encrypted)))
|
265
269
|
|
266
270
|
except sqlite3.IntegrityError as e:
|
267
271
|
if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
|
@@ -328,7 +332,7 @@ def set_service_load_balancer_port(service_name: str,
|
|
328
332
|
def _get_service_from_row(row) -> Dict[str, Any]:
|
329
333
|
(current_version, name, controller_job_id, controller_port,
|
330
334
|
load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
|
331
|
-
_, active_versions, load_balancing_policy) = row[:
|
335
|
+
_, active_versions, load_balancing_policy, tls_encrypted) = row[:15]
|
332
336
|
if load_balancing_policy is None:
|
333
337
|
# This entry in database was added in #4439, and it will always be set
|
334
338
|
# to a str value. If it is None, it means it is an legacy entry and is
|
@@ -351,6 +355,7 @@ def _get_service_from_row(row) -> Dict[str, Any]:
|
|
351
355
|
'active_versions': json.loads(active_versions),
|
352
356
|
'requested_resources_str': requested_resources_str,
|
353
357
|
'load_balancing_policy': load_balancing_policy,
|
358
|
+
'tls_encrypted': bool(tls_encrypted),
|
354
359
|
}
|
355
360
|
|
356
361
|
|
sky/serve/serve_utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""User interface with the SkyServe."""
|
2
2
|
import base64
|
3
3
|
import collections
|
4
|
+
import dataclasses
|
4
5
|
import enum
|
5
6
|
import os
|
6
7
|
import pathlib
|
@@ -92,6 +93,19 @@ class UpdateMode(enum.Enum):
|
|
92
93
|
BLUE_GREEN = 'blue_green'
|
93
94
|
|
94
95
|
|
96
|
+
@dataclasses.dataclass
|
97
|
+
class TLSCredential:
|
98
|
+
"""TLS credential for the service."""
|
99
|
+
keyfile: str
|
100
|
+
certfile: str
|
101
|
+
|
102
|
+
def dump_uvicorn_kwargs(self) -> Dict[str, str]:
|
103
|
+
return {
|
104
|
+
'ssl_keyfile': os.path.expanduser(self.keyfile),
|
105
|
+
'ssl_certfile': os.path.expanduser(self.certfile),
|
106
|
+
}
|
107
|
+
|
108
|
+
|
95
109
|
DEFAULT_UPDATE_MODE = UpdateMode.ROLLING
|
96
110
|
|
97
111
|
_SIGNAL_TO_ERROR = {
|
@@ -243,6 +257,18 @@ def generate_replica_log_file_name(service_name: str, replica_id: int) -> str:
|
|
243
257
|
return os.path.join(dir_name, f'replica_{replica_id}.log')
|
244
258
|
|
245
259
|
|
260
|
+
def generate_remote_tls_keyfile_name(service_name: str) -> str:
|
261
|
+
dir_name = generate_remote_service_dir_name(service_name)
|
262
|
+
# Don't expand here since it is used for remote machine.
|
263
|
+
return os.path.join(dir_name, 'tls_keyfile')
|
264
|
+
|
265
|
+
|
266
|
+
def generate_remote_tls_certfile_name(service_name: str) -> str:
|
267
|
+
dir_name = generate_remote_service_dir_name(service_name)
|
268
|
+
# Don't expand here since it is used for remote machine.
|
269
|
+
return os.path.join(dir_name, 'tls_certfile')
|
270
|
+
|
271
|
+
|
246
272
|
def generate_replica_cluster_name(service_name: str, replica_id: int) -> str:
|
247
273
|
return f'{service_name}-{replica_id}'
|
248
274
|
|
@@ -799,7 +825,8 @@ def get_endpoint(service_record: Dict[str, Any]) -> str:
|
|
799
825
|
if endpoint is None:
|
800
826
|
return '-'
|
801
827
|
assert isinstance(endpoint, str), endpoint
|
802
|
-
|
828
|
+
protocol = 'https' if service_record['tls_encrypted'] else 'http'
|
829
|
+
return f'{protocol}://{endpoint}'
|
803
830
|
|
804
831
|
|
805
832
|
def format_service_table(service_records: List[Dict[str, Any]],
|
sky/serve/service.py
CHANGED
@@ -151,7 +151,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
|
|
151
151
|
policy=service_spec.autoscaling_policy_str(),
|
152
152
|
requested_resources_str=backend_utils.get_task_resources_str(task),
|
153
153
|
load_balancing_policy=service_spec.load_balancing_policy,
|
154
|
-
status=serve_state.ServiceStatus.CONTROLLER_INIT
|
154
|
+
status=serve_state.ServiceStatus.CONTROLLER_INIT,
|
155
|
+
tls_encrypted=service_spec.tls_credential is not None)
|
155
156
|
# Directly throw an error here. See sky/serve/api.py::up
|
156
157
|
# for more details.
|
157
158
|
if not success:
|
@@ -214,7 +215,6 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
|
|
214
215
|
serve_state.set_service_controller_port(service_name,
|
215
216
|
controller_port)
|
216
217
|
|
217
|
-
# TODO(tian): Support HTTPS.
|
218
218
|
controller_addr = f'http://{controller_host}:{controller_port}'
|
219
219
|
|
220
220
|
load_balancer_port = common_utils.find_free_port(
|
@@ -231,7 +231,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
|
|
231
231
|
target=ux_utils.RedirectOutputForProcess(
|
232
232
|
load_balancer.run_load_balancer,
|
233
233
|
load_balancer_log_file).run,
|
234
|
-
args=(controller_addr, load_balancer_port, policy_name
|
234
|
+
args=(controller_addr, load_balancer_port, policy_name,
|
235
|
+
service_spec.tls_credential))
|
235
236
|
load_balancer_process.start()
|
236
237
|
serve_state.set_service_load_balancer_port(service_name,
|
237
238
|
load_balancer_port)
|
sky/serve/service_spec.py
CHANGED
@@ -9,6 +9,7 @@ import yaml
|
|
9
9
|
from sky import serve
|
10
10
|
from sky.serve import constants
|
11
11
|
from sky.serve import load_balancing_policies as lb_policies
|
12
|
+
from sky.serve import serve_utils
|
12
13
|
from sky.utils import common_utils
|
13
14
|
from sky.utils import schemas
|
14
15
|
from sky.utils import ux_utils
|
@@ -26,6 +27,7 @@ class SkyServiceSpec:
|
|
26
27
|
max_replicas: Optional[int] = None,
|
27
28
|
target_qps_per_replica: Optional[float] = None,
|
28
29
|
post_data: Optional[Dict[str, Any]] = None,
|
30
|
+
tls_credential: Optional[serve_utils.TLSCredential] = None,
|
29
31
|
readiness_headers: Optional[Dict[str, str]] = None,
|
30
32
|
dynamic_ondemand_fallback: Optional[bool] = None,
|
31
33
|
base_ondemand_fallback_replicas: Optional[int] = None,
|
@@ -72,6 +74,8 @@ class SkyServiceSpec:
|
|
72
74
|
self._max_replicas: Optional[int] = max_replicas
|
73
75
|
self._target_qps_per_replica: Optional[float] = target_qps_per_replica
|
74
76
|
self._post_data: Optional[Dict[str, Any]] = post_data
|
77
|
+
self._tls_credential: Optional[serve_utils.TLSCredential] = (
|
78
|
+
tls_credential)
|
75
79
|
self._readiness_headers: Optional[Dict[str, str]] = readiness_headers
|
76
80
|
self._dynamic_ondemand_fallback: Optional[
|
77
81
|
bool] = dynamic_ondemand_fallback
|
@@ -163,6 +167,14 @@ class SkyServiceSpec:
|
|
163
167
|
|
164
168
|
service_config['load_balancing_policy'] = config.get(
|
165
169
|
'load_balancing_policy', None)
|
170
|
+
|
171
|
+
tls_section = config.get('tls', None)
|
172
|
+
if tls_section is not None:
|
173
|
+
service_config['tls_credential'] = serve_utils.TLSCredential(
|
174
|
+
keyfile=tls_section.get('keyfile', None),
|
175
|
+
certfile=tls_section.get('certfile', None),
|
176
|
+
)
|
177
|
+
|
166
178
|
return SkyServiceSpec(**service_config)
|
167
179
|
|
168
180
|
@staticmethod
|
@@ -223,6 +235,9 @@ class SkyServiceSpec:
|
|
223
235
|
self.downscale_delay_seconds)
|
224
236
|
add_if_not_none('load_balancing_policy', None,
|
225
237
|
self._load_balancing_policy)
|
238
|
+
if self.tls_credential is not None:
|
239
|
+
add_if_not_none('tls', 'keyfile', self.tls_credential.keyfile)
|
240
|
+
add_if_not_none('tls', 'certfile', self.tls_credential.certfile)
|
226
241
|
return config
|
227
242
|
|
228
243
|
def probe_str(self):
|
@@ -267,12 +282,19 @@ class SkyServiceSpec:
|
|
267
282
|
f'replica{max_plural} (target QPS per replica: '
|
268
283
|
f'{self.target_qps_per_replica})')
|
269
284
|
|
285
|
+
def tls_str(self):
|
286
|
+
if self.tls_credential is None:
|
287
|
+
return 'No TLS Enabled'
|
288
|
+
return (f'Keyfile: {self.tls_credential.keyfile}, '
|
289
|
+
f'Certfile: {self.tls_credential.certfile}')
|
290
|
+
|
270
291
|
def __repr__(self) -> str:
|
271
292
|
return textwrap.dedent(f"""\
|
272
293
|
Readiness probe method: {self.probe_str()}
|
273
294
|
Readiness initial delay seconds: {self.initial_delay_seconds}
|
274
295
|
Readiness probe timeout seconds: {self.readiness_timeout_seconds}
|
275
296
|
Replica autoscaling policy: {self.autoscaling_policy_str()}
|
297
|
+
TLS Certificates: {self.tls_str()}
|
276
298
|
Spot Policy: {self.spot_policy_str()}
|
277
299
|
Load Balancing Policy: {self.load_balancing_policy}
|
278
300
|
""")
|
@@ -306,6 +328,15 @@ class SkyServiceSpec:
|
|
306
328
|
def post_data(self) -> Optional[Dict[str, Any]]:
|
307
329
|
return self._post_data
|
308
330
|
|
331
|
+
@property
|
332
|
+
def tls_credential(self) -> Optional[serve_utils.TLSCredential]:
|
333
|
+
return self._tls_credential
|
334
|
+
|
335
|
+
@tls_credential.setter
|
336
|
+
def tls_credential(self,
|
337
|
+
value: Optional[serve_utils.TLSCredential]) -> None:
|
338
|
+
self._tls_credential = value
|
339
|
+
|
309
340
|
@property
|
310
341
|
def readiness_headers(self) -> Optional[Dict[str, str]]:
|
311
342
|
return self._readiness_headers
|
sky/skylet/constants.py
CHANGED
@@ -86,7 +86,7 @@ TASK_ID_LIST_ENV_VAR = 'SKYPILOT_TASK_IDS'
|
|
86
86
|
# cluster yaml is updated.
|
87
87
|
#
|
88
88
|
# TODO(zongheng,zhanghao): make the upgrading of skylet automatic?
|
89
|
-
SKYLET_VERSION = '
|
89
|
+
SKYLET_VERSION = '10'
|
90
90
|
# The version of the lib files that skylet/jobs use. Whenever there is an API
|
91
91
|
# change for the job_lib or log_lib, we need to bump this version, so that the
|
92
92
|
# user can be notified to update their SkyPilot version on the remote cluster.
|
sky/skylet/events.py
CHANGED
@@ -13,6 +13,8 @@ from sky import clouds
|
|
13
13
|
from sky import sky_logging
|
14
14
|
from sky.backends import cloud_vm_ray_backend
|
15
15
|
from sky.clouds import cloud_registry
|
16
|
+
from sky.jobs import scheduler as managed_job_scheduler
|
17
|
+
from sky.jobs import state as managed_job_state
|
16
18
|
from sky.jobs import utils as managed_job_utils
|
17
19
|
from sky.serve import serve_utils
|
18
20
|
from sky.skylet import autostop_lib
|
@@ -67,12 +69,13 @@ class JobSchedulerEvent(SkyletEvent):
|
|
67
69
|
job_lib.scheduler.schedule_step(force_update_jobs=True)
|
68
70
|
|
69
71
|
|
70
|
-
class
|
71
|
-
"""Skylet event for updating managed
|
72
|
+
class ManagedJobEvent(SkyletEvent):
|
73
|
+
"""Skylet event for updating and scheduling managed jobs."""
|
72
74
|
EVENT_INTERVAL_SECONDS = 300
|
73
75
|
|
74
76
|
def _run(self):
|
75
77
|
managed_job_utils.update_managed_job_status()
|
78
|
+
managed_job_scheduler.maybe_schedule_next_jobs()
|
76
79
|
|
77
80
|
|
78
81
|
class ServiceUpdateEvent(SkyletEvent):
|
@@ -116,7 +119,8 @@ class AutostopEvent(SkyletEvent):
|
|
116
119
|
logger.debug('autostop_config not set. Skipped.')
|
117
120
|
return
|
118
121
|
|
119
|
-
if job_lib.is_cluster_idle()
|
122
|
+
if (job_lib.is_cluster_idle() and
|
123
|
+
not managed_job_state.get_num_alive_jobs()):
|
120
124
|
idle_minutes = (time.time() -
|
121
125
|
autostop_lib.get_last_active_time()) // 60
|
122
126
|
logger.debug(
|
sky/skylet/job_lib.py
CHANGED
@@ -10,9 +10,8 @@ import pathlib
|
|
10
10
|
import shlex
|
11
11
|
import signal
|
12
12
|
import sqlite3
|
13
|
-
import subprocess
|
14
13
|
import time
|
15
|
-
from typing import Any, Dict, List, Optional
|
14
|
+
from typing import Any, Dict, List, Optional, Sequence
|
16
15
|
|
17
16
|
import colorama
|
18
17
|
import filelock
|
@@ -23,6 +22,7 @@ from sky.skylet import constants
|
|
23
22
|
from sky.utils import common_utils
|
24
23
|
from sky.utils import db_utils
|
25
24
|
from sky.utils import log_utils
|
25
|
+
from sky.utils import subprocess_utils
|
26
26
|
|
27
27
|
logger = sky_logging.init_logger(__name__)
|
28
28
|
|
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
|
|
162
162
|
def nonterminal_statuses(cls) -> List['JobStatus']:
|
163
163
|
return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]
|
164
164
|
|
165
|
-
def is_terminal(self):
|
165
|
+
def is_terminal(self) -> bool:
|
166
166
|
return self not in self.nonterminal_statuses()
|
167
167
|
|
168
|
-
|
168
|
+
@classmethod
|
169
|
+
def user_code_failure_states(cls) -> Sequence['JobStatus']:
|
170
|
+
return (cls.FAILED, cls.FAILED_SETUP)
|
171
|
+
|
172
|
+
def __lt__(self, other: 'JobStatus') -> bool:
|
169
173
|
return list(JobStatus).index(self) < list(JobStatus).index(other)
|
170
174
|
|
171
|
-
def colored_str(self):
|
175
|
+
def colored_str(self) -> str:
|
172
176
|
color = _JOB_STATUS_TO_COLOR[self]
|
173
177
|
return f'{color}{self.value}{colorama.Style.RESET_ALL}'
|
174
178
|
|
@@ -205,31 +209,7 @@ class JobScheduler:
|
|
205
209
|
_CURSOR.execute((f'UPDATE pending_jobs SET submit={int(time.time())} '
|
206
210
|
f'WHERE job_id={job_id!r}'))
|
207
211
|
_CONN.commit()
|
208
|
-
|
209
|
-
# instead of being a child of the current process. This is important to
|
210
|
-
# avoid a chain of driver processes (job driver can call schedule_step()
|
211
|
-
# to submit new jobs, and the new job can also call schedule_step()
|
212
|
-
# recursively).
|
213
|
-
#
|
214
|
-
# echo $! will output the PID of the last background process started
|
215
|
-
# in the current shell, so we can retrieve it and record in the DB.
|
216
|
-
#
|
217
|
-
# TODO(zhwu): A more elegant solution is to use another daemon process
|
218
|
-
# to be in charge of starting these driver processes, instead of
|
219
|
-
# starting them in the current process.
|
220
|
-
wrapped_cmd = (f'nohup bash -c {shlex.quote(run_cmd)} '
|
221
|
-
'</dev/null >/dev/null 2>&1 & echo $!')
|
222
|
-
proc = subprocess.run(wrapped_cmd,
|
223
|
-
stdout=subprocess.PIPE,
|
224
|
-
stderr=subprocess.PIPE,
|
225
|
-
stdin=subprocess.DEVNULL,
|
226
|
-
start_new_session=True,
|
227
|
-
check=True,
|
228
|
-
shell=True,
|
229
|
-
text=True)
|
230
|
-
# Get the PID of the detached process
|
231
|
-
pid = int(proc.stdout.strip())
|
232
|
-
|
212
|
+
pid = subprocess_utils.launch_new_process_tree(run_cmd)
|
233
213
|
# TODO(zhwu): Backward compatibility, remove this check after 0.10.0.
|
234
214
|
# This is for the case where the job is submitted with SkyPilot older
|
235
215
|
# than #4318, using ray job submit.
|
sky/skylet/log_lib.py
CHANGED
@@ -25,9 +25,9 @@ from sky.utils import log_utils
|
|
25
25
|
from sky.utils import subprocess_utils
|
26
26
|
from sky.utils import ux_utils
|
27
27
|
|
28
|
-
|
29
|
-
|
30
|
-
|
28
|
+
SKY_LOG_WAITING_GAP_SECONDS = 1
|
29
|
+
SKY_LOG_WAITING_MAX_RETRY = 5
|
30
|
+
SKY_LOG_TAILING_GAP_SECONDS = 0.2
|
31
31
|
# Peek the head of the lines to check if we need to start
|
32
32
|
# streaming when tail > 0.
|
33
33
|
PEEK_HEAD_LINES_FOR_START_STREAM = 20
|
@@ -336,7 +336,7 @@ def _follow_job_logs(file,
|
|
336
336
|
]:
|
337
337
|
if wait_last_logs:
|
338
338
|
# Wait all the logs are printed before exit.
|
339
|
-
time.sleep(1 +
|
339
|
+
time.sleep(1 + SKY_LOG_TAILING_GAP_SECONDS)
|
340
340
|
wait_last_logs = False
|
341
341
|
continue
|
342
342
|
status_str = status.value if status is not None else 'None'
|
@@ -345,7 +345,7 @@ def _follow_job_logs(file,
|
|
345
345
|
f'Job finished (status: {status_str}).'))
|
346
346
|
return
|
347
347
|
|
348
|
-
time.sleep(
|
348
|
+
time.sleep(SKY_LOG_TAILING_GAP_SECONDS)
|
349
349
|
status = job_lib.get_status_no_lock(job_id)
|
350
350
|
|
351
351
|
|
@@ -426,15 +426,15 @@ def tail_logs(job_id: Optional[int],
|
|
426
426
|
retry_cnt += 1
|
427
427
|
if os.path.exists(log_path) and status != job_lib.JobStatus.INIT:
|
428
428
|
break
|
429
|
-
if retry_cnt >=
|
429
|
+
if retry_cnt >= SKY_LOG_WAITING_MAX_RETRY:
|
430
430
|
print(
|
431
431
|
f'{colorama.Fore.RED}ERROR: Logs for '
|
432
432
|
f'{job_str} (status: {status.value}) does not exist '
|
433
433
|
f'after retrying {retry_cnt} times.{colorama.Style.RESET_ALL}')
|
434
434
|
return
|
435
|
-
print(f'INFO: Waiting {
|
435
|
+
print(f'INFO: Waiting {SKY_LOG_WAITING_GAP_SECONDS}s for the logs '
|
436
436
|
'to be written...')
|
437
|
-
time.sleep(
|
437
|
+
time.sleep(SKY_LOG_WAITING_GAP_SECONDS)
|
438
438
|
status = job_lib.update_job_status([job_id], silent=True)[0]
|
439
439
|
|
440
440
|
start_stream_at = LOG_FILE_START_STREAMING_AT
|