skypilot-nightly 1.0.0.dev20250114__py3-none-any.whl → 1.0.0.dev20250124__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. sky/__init__.py +2 -2
  2. sky/backends/cloud_vm_ray_backend.py +50 -67
  3. sky/check.py +31 -1
  4. sky/cli.py +11 -34
  5. sky/clouds/kubernetes.py +3 -3
  6. sky/clouds/service_catalog/kubernetes_catalog.py +14 -0
  7. sky/core.py +8 -5
  8. sky/data/storage.py +66 -14
  9. sky/global_user_state.py +1 -1
  10. sky/jobs/constants.py +8 -7
  11. sky/jobs/controller.py +19 -22
  12. sky/jobs/core.py +0 -2
  13. sky/jobs/recovery_strategy.py +114 -143
  14. sky/jobs/scheduler.py +283 -0
  15. sky/jobs/state.py +263 -21
  16. sky/jobs/utils.py +338 -96
  17. sky/provision/aws/config.py +48 -26
  18. sky/provision/gcp/instance_utils.py +15 -9
  19. sky/provision/kubernetes/instance.py +1 -1
  20. sky/provision/kubernetes/utils.py +76 -18
  21. sky/resources.py +1 -1
  22. sky/serve/autoscalers.py +359 -301
  23. sky/serve/controller.py +10 -8
  24. sky/serve/core.py +84 -7
  25. sky/serve/load_balancer.py +27 -10
  26. sky/serve/replica_managers.py +1 -3
  27. sky/serve/serve_state.py +10 -5
  28. sky/serve/serve_utils.py +28 -1
  29. sky/serve/service.py +4 -3
  30. sky/serve/service_spec.py +31 -0
  31. sky/skylet/constants.py +1 -1
  32. sky/skylet/events.py +7 -3
  33. sky/skylet/job_lib.py +10 -30
  34. sky/skylet/log_lib.py +8 -8
  35. sky/skylet/log_lib.pyi +3 -0
  36. sky/skylet/skylet.py +1 -1
  37. sky/templates/jobs-controller.yaml.j2 +7 -3
  38. sky/templates/sky-serve-controller.yaml.j2 +4 -0
  39. sky/utils/db_utils.py +18 -4
  40. sky/utils/kubernetes/deploy_remote_cluster.sh +5 -5
  41. sky/utils/resources_utils.py +25 -21
  42. sky/utils/schemas.py +13 -0
  43. sky/utils/subprocess_utils.py +48 -9
  44. {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/METADATA +4 -1
  45. {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/RECORD +49 -48
  46. {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/LICENSE +0 -0
  47. {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/WHEEL +0 -0
  48. {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/entry_points.txt +0 -0
  49. {skypilot_nightly-1.0.0.dev20250114.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/top_level.txt +0 -0
sky/serve/controller.py CHANGED
@@ -67,9 +67,16 @@ class SkyServeController:
67
67
  try:
68
68
  replica_infos = serve_state.get_replica_infos(
69
69
  self._service_name)
70
+ # Use the active versions set by replica manager to make
71
+ # sure we only scale down the outdated replicas that are
72
+ # not used by the load balancer.
73
+ record = serve_state.get_service_from_name(self._service_name)
74
+ assert record is not None, ('No service record found for '
75
+ f'{self._service_name}')
76
+ active_versions = record['active_versions']
70
77
  logger.info(f'All replica info: {replica_infos}')
71
- scaling_options = self._autoscaler.evaluate_scaling(
72
- replica_infos)
78
+ scaling_options = self._autoscaler.generate_scaling_decisions(
79
+ replica_infos, active_versions)
73
80
  for scaling_option in scaling_options:
74
81
  logger.info(f'Scaling option received: {scaling_option}')
75
82
  if (scaling_option.operator ==
@@ -77,15 +84,10 @@ class SkyServeController:
77
84
  assert (scaling_option.target is None or isinstance(
78
85
  scaling_option.target, dict)), scaling_option
79
86
  self._replica_manager.scale_up(scaling_option.target)
80
- elif (scaling_option.operator ==
81
- autoscalers.AutoscalerDecisionOperator.SCALE_DOWN):
87
+ else:
82
88
  assert isinstance(scaling_option.target,
83
89
  int), scaling_option
84
90
  self._replica_manager.scale_down(scaling_option.target)
85
- else:
86
- with ux_utils.enable_traceback():
87
- logger.error('Error in scaling_option.operator: '
88
- f'{scaling_option.operator}')
89
91
  except Exception as e: # pylint: disable=broad-except
90
92
  # No matter what error happens, we should keep the
91
93
  # monitor running.
sky/serve/core.py CHANGED
@@ -1,6 +1,9 @@
1
1
  """SkyServe core APIs."""
2
2
  import re
3
+ import signal
4
+ import subprocess
3
5
  import tempfile
6
+ import threading
4
7
  from typing import Any, Dict, List, Optional, Tuple, Union
5
8
 
6
9
  import colorama
@@ -18,6 +21,7 @@ from sky.serve import serve_utils
18
21
  from sky.skylet import constants
19
22
  from sky.usage import usage_lib
20
23
  from sky.utils import admin_policy_utils
24
+ from sky.utils import command_runner
21
25
  from sky.utils import common_utils
22
26
  from sky.utils import controller_utils
23
27
  from sky.utils import resources_utils
@@ -91,6 +95,38 @@ def _validate_service_task(task: 'sky.Task') -> None:
91
95
  'Please specify the same port instead.')
92
96
 
93
97
 
98
+ def _rewrite_tls_credential_paths_and_get_tls_env_vars(
99
+ service_name: str, task: 'sky.Task') -> Dict[str, Any]:
100
+ """Rewrite the paths of TLS credentials in the task.
101
+
102
+ Args:
103
+ service_name: Name of the service.
104
+ task: sky.Task to rewrite.
105
+
106
+ Returns:
107
+ The generated template variables for TLS.
108
+ """
109
+ service_spec = task.service
110
+ # Already checked by _validate_service_task
111
+ assert service_spec is not None
112
+ if service_spec.tls_credential is None:
113
+ return {'use_tls': False}
114
+ remote_tls_keyfile = (
115
+ serve_utils.generate_remote_tls_keyfile_name(service_name))
116
+ remote_tls_certfile = (
117
+ serve_utils.generate_remote_tls_certfile_name(service_name))
118
+ tls_template_vars = {
119
+ 'use_tls': True,
120
+ 'remote_tls_keyfile': remote_tls_keyfile,
121
+ 'remote_tls_certfile': remote_tls_certfile,
122
+ 'local_tls_keyfile': service_spec.tls_credential.keyfile,
123
+ 'local_tls_certfile': service_spec.tls_credential.certfile,
124
+ }
125
+ service_spec.tls_credential = serve_utils.TLSCredential(
126
+ remote_tls_keyfile, remote_tls_certfile)
127
+ return tls_template_vars
128
+
129
+
94
130
  @usage_lib.entrypoint
95
131
  def up(
96
132
  task: 'sky.Task',
@@ -136,6 +172,9 @@ def up(
136
172
  controller_utils.maybe_translate_local_file_mounts_and_sync_up(
137
173
  task, path='serve')
138
174
 
175
+ tls_template_vars = _rewrite_tls_credential_paths_and_get_tls_env_vars(
176
+ service_name, task)
177
+
139
178
  with tempfile.NamedTemporaryFile(
140
179
  prefix=f'service-task-{service_name}-',
141
180
  mode='w',
@@ -164,6 +203,7 @@ def up(
164
203
  'remote_user_config_path': remote_config_yaml_path,
165
204
  'modified_catalogs':
166
205
  service_catalog_common.get_modified_catalog_file_mounts(),
206
+ **tls_template_vars,
167
207
  **controller_utils.shared_controller_vars_to_fill(
168
208
  controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
169
209
  remote_user_config_path=remote_config_yaml_path,
@@ -269,10 +309,16 @@ def up(
269
309
  else:
270
310
  lb_port = serve_utils.load_service_initialization_result(
271
311
  lb_port_payload)
272
- endpoint = backend_utils.get_endpoints(
312
+ socket_endpoint = backend_utils.get_endpoints(
273
313
  controller_handle.cluster_name, lb_port,
274
314
  skip_status_check=True).get(lb_port)
275
- assert endpoint is not None, 'Did not get endpoint for controller.'
315
+ assert socket_endpoint is not None, (
316
+ 'Did not get endpoint for controller.')
317
+ # Already checked by _validate_service_task
318
+ assert task.service is not None
319
+ protocol = ('http'
320
+ if task.service.tls_credential is None else 'https')
321
+ endpoint = f'{protocol}://{socket_endpoint}'
276
322
 
277
323
  sky_logging.print(
278
324
  f'{fore.CYAN}Service name: '
@@ -321,6 +367,7 @@ def update(
321
367
  service_name: Name of the service.
322
368
  """
323
369
  _validate_service_task(task)
370
+
324
371
  # Always apply the policy again here, even though it might have been applied
325
372
  # in the CLI. This is to ensure that we apply the policy to the final DAG
326
373
  # and get the mutated config.
@@ -329,6 +376,14 @@ def update(
329
376
  dag, _ = admin_policy_utils.apply(
330
377
  task, use_mutated_config_in_current_request=False)
331
378
  task = dag.tasks[0]
379
+
380
+ assert task.service is not None
381
+ if task.service.tls_credential is not None:
382
+ logger.warning('Updating TLS keyfile and certfile is not supported. '
383
+ 'Any updates to the keyfile and certfile will not take '
384
+ 'effect. To update TLS keyfile and certfile, please '
385
+ 'tear down the service and spin up a new one.')
386
+
332
387
  handle = backend_utils.is_controller_accessible(
333
388
  controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
334
389
  stopped_message=
@@ -596,6 +651,7 @@ def status(
596
651
  'requested_resources_str': (str) str representation of
597
652
  requested resources,
598
653
  'load_balancing_policy': (str) load balancing policy name,
654
+ 'tls_encrypted': (bool) whether the service is TLS encrypted,
599
655
  'replica_info': (List[Dict[str, Any]]) replica information,
600
656
  }
601
657
 
@@ -731,8 +787,29 @@ def tail_logs(
731
787
 
732
788
  backend = backend_utils.get_backend_from_handle(handle)
733
789
  assert isinstance(backend, backends.CloudVmRayBackend), backend
734
- backend.tail_serve_logs(handle,
735
- service_name,
736
- target,
737
- replica_id,
738
- follow=follow)
790
+
791
+ if target != serve_utils.ServiceComponent.REPLICA:
792
+ code = serve_utils.ServeCodeGen.stream_serve_process_logs(
793
+ service_name,
794
+ stream_controller=(
795
+ target == serve_utils.ServiceComponent.CONTROLLER),
796
+ follow=follow)
797
+ else:
798
+ assert replica_id is not None, service_name
799
+ code = serve_utils.ServeCodeGen.stream_replica_logs(
800
+ service_name, replica_id, follow)
801
+
802
+ # With the stdin=subprocess.DEVNULL, the ctrl-c will not directly
803
+ # kill the process, so we need to handle it manually here.
804
+ if threading.current_thread() is threading.main_thread():
805
+ signal.signal(signal.SIGINT, backend_utils.interrupt_handler)
806
+ signal.signal(signal.SIGTSTP, backend_utils.stop_handler)
807
+
808
+ # Refer to the notes in
809
+ # sky/backends/cloud_vm_ray_backend.py::CloudVmRayBackend::tail_logs.
810
+ backend.run_on_head(handle,
811
+ code,
812
+ stream_logs=True,
813
+ process_stream=False,
814
+ ssh_mode=command_runner.SshMode.INTERACTIVE,
815
+ stdin=subprocess.DEVNULL)
@@ -27,10 +27,12 @@ class SkyServeLoadBalancer:
27
27
  policy.
28
28
  """
29
29
 
30
- def __init__(self,
31
- controller_url: str,
32
- load_balancer_port: int,
33
- load_balancing_policy_name: Optional[str] = None) -> None:
30
+ def __init__(
31
+ self,
32
+ controller_url: str,
33
+ load_balancer_port: int,
34
+ load_balancing_policy_name: Optional[str] = None,
35
+ tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
34
36
  """Initialize the load balancer.
35
37
 
36
38
  Args:
@@ -38,6 +40,8 @@ class SkyServeLoadBalancer:
38
40
  load_balancer_port: The port where the load balancer listens to.
39
41
  load_balancing_policy_name: The name of the load balancing policy
40
42
  to use. Defaults to None.
43
+ tls_credential: The TLS credentials for HTTPS endpoint. Defaults
44
+ to None.
41
45
  """
42
46
  self._app = fastapi.FastAPI()
43
47
  self._controller_url: str = controller_url
@@ -49,6 +53,8 @@ class SkyServeLoadBalancer:
49
53
  f'{load_balancing_policy_name}.')
50
54
  self._request_aggregator: serve_utils.RequestsAggregator = (
51
55
  serve_utils.RequestTimestamp())
56
+ self._tls_credential: Optional[serve_utils.TLSCredential] = (
57
+ tls_credential)
52
58
  # TODO(tian): httpx.Client has a resource limit of 100 max connections
53
59
  # for each client. We should wait for feedback on the best max
54
60
  # connections.
@@ -231,15 +237,25 @@ class SkyServeLoadBalancer:
231
237
  # Register controller synchronization task
232
238
  asyncio.create_task(self._sync_with_controller())
233
239
 
240
+ uvicorn_tls_kwargs = ({} if self._tls_credential is None else
241
+ self._tls_credential.dump_uvicorn_kwargs())
242
+
243
+ protocol = 'https' if self._tls_credential is not None else 'http'
244
+
234
245
  logger.info('SkyServe Load Balancer started on '
235
- f'http://0.0.0.0:{self._load_balancer_port}')
246
+ f'{protocol}://0.0.0.0:{self._load_balancer_port}')
236
247
 
237
- uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port)
248
+ uvicorn.run(self._app,
249
+ host='0.0.0.0',
250
+ port=self._load_balancer_port,
251
+ **uvicorn_tls_kwargs)
238
252
 
239
253
 
240
- def run_load_balancer(controller_addr: str,
241
- load_balancer_port: int,
242
- load_balancing_policy_name: Optional[str] = None) -> None:
254
+ def run_load_balancer(
255
+ controller_addr: str,
256
+ load_balancer_port: int,
257
+ load_balancing_policy_name: Optional[str] = None,
258
+ tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
243
259
  """ Run the load balancer.
244
260
 
245
261
  Args:
@@ -251,7 +267,8 @@ def run_load_balancer(controller_addr: str,
251
267
  load_balancer = SkyServeLoadBalancer(
252
268
  controller_url=controller_addr,
253
269
  load_balancer_port=load_balancer_port,
254
- load_balancing_policy_name=load_balancing_policy_name)
270
+ load_balancing_policy_name=load_balancing_policy_name,
271
+ tls_credential=tls_credential)
255
272
  load_balancer.run()
256
273
 
257
274
 
@@ -998,9 +998,7 @@ class SkyPilotReplicaManager(ReplicaManager):
998
998
  # Re-raise the exception if it is not preempted.
999
999
  raise
1000
1000
  job_status = list(job_statuses.values())[0]
1001
- if job_status in [
1002
- job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
1003
- ]:
1001
+ if job_status in job_lib.JobStatus.user_code_failure_states():
1004
1002
  info.status_property.user_app_failed = True
1005
1003
  serve_state.add_or_update_replica(self._service_name,
1006
1004
  info.replica_id, info)
sky/serve/serve_state.py CHANGED
@@ -79,6 +79,9 @@ db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
79
79
  f'TEXT DEFAULT {json.dumps([])!r}')
80
80
  db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
81
81
  'load_balancing_policy', 'TEXT DEFAULT NULL')
82
+ # Whether the service's load balancer is encrypted with TLS.
83
+ db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'tls_encrypted',
84
+ 'INTEGER DEFAULT 0')
82
85
  _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'
83
86
 
84
87
 
@@ -245,7 +248,7 @@ _SERVICE_STATUS_TO_COLOR = {
245
248
 
246
249
  def add_service(name: str, controller_job_id: int, policy: str,
247
250
  requested_resources_str: str, load_balancing_policy: str,
248
- status: ServiceStatus) -> bool:
251
+ status: ServiceStatus, tls_encrypted: bool) -> bool:
249
252
  """Add a service in the database.
250
253
 
251
254
  Returns:
@@ -258,10 +261,11 @@ def add_service(name: str, controller_job_id: int, policy: str,
258
261
  """\
259
262
  INSERT INTO services
260
263
  (name, controller_job_id, status, policy,
261
- requested_resources_str, load_balancing_policy)
262
- VALUES (?, ?, ?, ?, ?, ?)""",
264
+ requested_resources_str, load_balancing_policy, tls_encrypted)
265
+ VALUES (?, ?, ?, ?, ?, ?, ?)""",
263
266
  (name, controller_job_id, status.value, policy,
264
- requested_resources_str, load_balancing_policy))
267
+ requested_resources_str, load_balancing_policy,
268
+ int(tls_encrypted)))
265
269
 
266
270
  except sqlite3.IntegrityError as e:
267
271
  if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
@@ -328,7 +332,7 @@ def set_service_load_balancer_port(service_name: str,
328
332
  def _get_service_from_row(row) -> Dict[str, Any]:
329
333
  (current_version, name, controller_job_id, controller_port,
330
334
  load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
331
- _, active_versions, load_balancing_policy) = row[:14]
335
+ _, active_versions, load_balancing_policy, tls_encrypted) = row[:15]
332
336
  if load_balancing_policy is None:
333
337
  # This entry in database was added in #4439, and it will always be set
334
338
  # to a str value. If it is None, it means it is a legacy entry and is
@@ -351,6 +355,7 @@ def _get_service_from_row(row) -> Dict[str, Any]:
351
355
  'active_versions': json.loads(active_versions),
352
356
  'requested_resources_str': requested_resources_str,
353
357
  'load_balancing_policy': load_balancing_policy,
358
+ 'tls_encrypted': bool(tls_encrypted),
354
359
  }
355
360
 
356
361
 
sky/serve/serve_utils.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """User interface with the SkyServe."""
2
2
  import base64
3
3
  import collections
4
+ import dataclasses
4
5
  import enum
5
6
  import os
6
7
  import pathlib
@@ -92,6 +93,19 @@ class UpdateMode(enum.Enum):
92
93
  BLUE_GREEN = 'blue_green'
93
94
 
94
95
 
96
+ @dataclasses.dataclass
97
+ class TLSCredential:
98
+ """TLS credential for the service."""
99
+ keyfile: str
100
+ certfile: str
101
+
102
+ def dump_uvicorn_kwargs(self) -> Dict[str, str]:
103
+ return {
104
+ 'ssl_keyfile': os.path.expanduser(self.keyfile),
105
+ 'ssl_certfile': os.path.expanduser(self.certfile),
106
+ }
107
+
108
+
95
109
  DEFAULT_UPDATE_MODE = UpdateMode.ROLLING
96
110
 
97
111
  _SIGNAL_TO_ERROR = {
@@ -243,6 +257,18 @@ def generate_replica_log_file_name(service_name: str, replica_id: int) -> str:
243
257
  return os.path.join(dir_name, f'replica_{replica_id}.log')
244
258
 
245
259
 
260
+ def generate_remote_tls_keyfile_name(service_name: str) -> str:
261
+ dir_name = generate_remote_service_dir_name(service_name)
262
+ # Don't expand here since it is used for remote machine.
263
+ return os.path.join(dir_name, 'tls_keyfile')
264
+
265
+
266
+ def generate_remote_tls_certfile_name(service_name: str) -> str:
267
+ dir_name = generate_remote_service_dir_name(service_name)
268
+ # Don't expand here since it is used for remote machine.
269
+ return os.path.join(dir_name, 'tls_certfile')
270
+
271
+
246
272
  def generate_replica_cluster_name(service_name: str, replica_id: int) -> str:
247
273
  return f'{service_name}-{replica_id}'
248
274
 
@@ -799,7 +825,8 @@ def get_endpoint(service_record: Dict[str, Any]) -> str:
799
825
  if endpoint is None:
800
826
  return '-'
801
827
  assert isinstance(endpoint, str), endpoint
802
- return endpoint
828
+ protocol = 'https' if service_record['tls_encrypted'] else 'http'
829
+ return f'{protocol}://{endpoint}'
803
830
 
804
831
 
805
832
  def format_service_table(service_records: List[Dict[str, Any]],
sky/serve/service.py CHANGED
@@ -151,7 +151,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
151
151
  policy=service_spec.autoscaling_policy_str(),
152
152
  requested_resources_str=backend_utils.get_task_resources_str(task),
153
153
  load_balancing_policy=service_spec.load_balancing_policy,
154
- status=serve_state.ServiceStatus.CONTROLLER_INIT)
154
+ status=serve_state.ServiceStatus.CONTROLLER_INIT,
155
+ tls_encrypted=service_spec.tls_credential is not None)
155
156
  # Directly throw an error here. See sky/serve/api.py::up
156
157
  # for more details.
157
158
  if not success:
@@ -214,7 +215,6 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
214
215
  serve_state.set_service_controller_port(service_name,
215
216
  controller_port)
216
217
 
217
- # TODO(tian): Support HTTPS.
218
218
  controller_addr = f'http://{controller_host}:{controller_port}'
219
219
 
220
220
  load_balancer_port = common_utils.find_free_port(
@@ -231,7 +231,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
231
231
  target=ux_utils.RedirectOutputForProcess(
232
232
  load_balancer.run_load_balancer,
233
233
  load_balancer_log_file).run,
234
- args=(controller_addr, load_balancer_port, policy_name))
234
+ args=(controller_addr, load_balancer_port, policy_name,
235
+ service_spec.tls_credential))
235
236
  load_balancer_process.start()
236
237
  serve_state.set_service_load_balancer_port(service_name,
237
238
  load_balancer_port)
sky/serve/service_spec.py CHANGED
@@ -9,6 +9,7 @@ import yaml
9
9
  from sky import serve
10
10
  from sky.serve import constants
11
11
  from sky.serve import load_balancing_policies as lb_policies
12
+ from sky.serve import serve_utils
12
13
  from sky.utils import common_utils
13
14
  from sky.utils import schemas
14
15
  from sky.utils import ux_utils
@@ -26,6 +27,7 @@ class SkyServiceSpec:
26
27
  max_replicas: Optional[int] = None,
27
28
  target_qps_per_replica: Optional[float] = None,
28
29
  post_data: Optional[Dict[str, Any]] = None,
30
+ tls_credential: Optional[serve_utils.TLSCredential] = None,
29
31
  readiness_headers: Optional[Dict[str, str]] = None,
30
32
  dynamic_ondemand_fallback: Optional[bool] = None,
31
33
  base_ondemand_fallback_replicas: Optional[int] = None,
@@ -72,6 +74,8 @@ class SkyServiceSpec:
72
74
  self._max_replicas: Optional[int] = max_replicas
73
75
  self._target_qps_per_replica: Optional[float] = target_qps_per_replica
74
76
  self._post_data: Optional[Dict[str, Any]] = post_data
77
+ self._tls_credential: Optional[serve_utils.TLSCredential] = (
78
+ tls_credential)
75
79
  self._readiness_headers: Optional[Dict[str, str]] = readiness_headers
76
80
  self._dynamic_ondemand_fallback: Optional[
77
81
  bool] = dynamic_ondemand_fallback
@@ -163,6 +167,14 @@ class SkyServiceSpec:
163
167
 
164
168
  service_config['load_balancing_policy'] = config.get(
165
169
  'load_balancing_policy', None)
170
+
171
+ tls_section = config.get('tls', None)
172
+ if tls_section is not None:
173
+ service_config['tls_credential'] = serve_utils.TLSCredential(
174
+ keyfile=tls_section.get('keyfile', None),
175
+ certfile=tls_section.get('certfile', None),
176
+ )
177
+
166
178
  return SkyServiceSpec(**service_config)
167
179
 
168
180
  @staticmethod
@@ -223,6 +235,9 @@ class SkyServiceSpec:
223
235
  self.downscale_delay_seconds)
224
236
  add_if_not_none('load_balancing_policy', None,
225
237
  self._load_balancing_policy)
238
+ if self.tls_credential is not None:
239
+ add_if_not_none('tls', 'keyfile', self.tls_credential.keyfile)
240
+ add_if_not_none('tls', 'certfile', self.tls_credential.certfile)
226
241
  return config
227
242
 
228
243
  def probe_str(self):
@@ -267,12 +282,19 @@ class SkyServiceSpec:
267
282
  f'replica{max_plural} (target QPS per replica: '
268
283
  f'{self.target_qps_per_replica})')
269
284
 
285
+ def tls_str(self):
286
+ if self.tls_credential is None:
287
+ return 'No TLS Enabled'
288
+ return (f'Keyfile: {self.tls_credential.keyfile}, '
289
+ f'Certfile: {self.tls_credential.certfile}')
290
+
270
291
  def __repr__(self) -> str:
271
292
  return textwrap.dedent(f"""\
272
293
  Readiness probe method: {self.probe_str()}
273
294
  Readiness initial delay seconds: {self.initial_delay_seconds}
274
295
  Readiness probe timeout seconds: {self.readiness_timeout_seconds}
275
296
  Replica autoscaling policy: {self.autoscaling_policy_str()}
297
+ TLS Certificates: {self.tls_str()}
276
298
  Spot Policy: {self.spot_policy_str()}
277
299
  Load Balancing Policy: {self.load_balancing_policy}
278
300
  """)
@@ -306,6 +328,15 @@ class SkyServiceSpec:
306
328
  def post_data(self) -> Optional[Dict[str, Any]]:
307
329
  return self._post_data
308
330
 
331
+ @property
332
+ def tls_credential(self) -> Optional[serve_utils.TLSCredential]:
333
+ return self._tls_credential
334
+
335
+ @tls_credential.setter
336
+ def tls_credential(self,
337
+ value: Optional[serve_utils.TLSCredential]) -> None:
338
+ self._tls_credential = value
339
+
309
340
  @property
310
341
  def readiness_headers(self) -> Optional[Dict[str, str]]:
311
342
  return self._readiness_headers
sky/skylet/constants.py CHANGED
@@ -86,7 +86,7 @@ TASK_ID_LIST_ENV_VAR = 'SKYPILOT_TASK_IDS'
86
86
  # cluster yaml is updated.
87
87
  #
88
88
  # TODO(zongheng,zhanghao): make the upgrading of skylet automatic?
89
- SKYLET_VERSION = '9'
89
+ SKYLET_VERSION = '10'
90
90
  # The version of the lib files that skylet/jobs use. Whenever there is an API
91
91
  # change for the job_lib or log_lib, we need to bump this version, so that the
92
92
  # user can be notified to update their SkyPilot version on the remote cluster.
sky/skylet/events.py CHANGED
@@ -13,6 +13,8 @@ from sky import clouds
13
13
  from sky import sky_logging
14
14
  from sky.backends import cloud_vm_ray_backend
15
15
  from sky.clouds import cloud_registry
16
+ from sky.jobs import scheduler as managed_job_scheduler
17
+ from sky.jobs import state as managed_job_state
16
18
  from sky.jobs import utils as managed_job_utils
17
19
  from sky.serve import serve_utils
18
20
  from sky.skylet import autostop_lib
@@ -67,12 +69,13 @@ class JobSchedulerEvent(SkyletEvent):
67
69
  job_lib.scheduler.schedule_step(force_update_jobs=True)
68
70
 
69
71
 
70
- class ManagedJobUpdateEvent(SkyletEvent):
71
- """Skylet event for updating managed job status."""
72
+ class ManagedJobEvent(SkyletEvent):
73
+ """Skylet event for updating and scheduling managed jobs."""
72
74
  EVENT_INTERVAL_SECONDS = 300
73
75
 
74
76
  def _run(self):
75
77
  managed_job_utils.update_managed_job_status()
78
+ managed_job_scheduler.maybe_schedule_next_jobs()
76
79
 
77
80
 
78
81
  class ServiceUpdateEvent(SkyletEvent):
@@ -116,7 +119,8 @@ class AutostopEvent(SkyletEvent):
116
119
  logger.debug('autostop_config not set. Skipped.')
117
120
  return
118
121
 
119
- if job_lib.is_cluster_idle():
122
+ if (job_lib.is_cluster_idle() and
123
+ not managed_job_state.get_num_alive_jobs()):
120
124
  idle_minutes = (time.time() -
121
125
  autostop_lib.get_last_active_time()) // 60
122
126
  logger.debug(
sky/skylet/job_lib.py CHANGED
@@ -10,9 +10,8 @@ import pathlib
10
10
  import shlex
11
11
  import signal
12
12
  import sqlite3
13
- import subprocess
14
13
  import time
15
- from typing import Any, Dict, List, Optional
14
+ from typing import Any, Dict, List, Optional, Sequence
16
15
 
17
16
  import colorama
18
17
  import filelock
@@ -23,6 +22,7 @@ from sky.skylet import constants
23
22
  from sky.utils import common_utils
24
23
  from sky.utils import db_utils
25
24
  from sky.utils import log_utils
25
+ from sky.utils import subprocess_utils
26
26
 
27
27
  logger = sky_logging.init_logger(__name__)
28
28
 
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
162
162
  def nonterminal_statuses(cls) -> List['JobStatus']:
163
163
  return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]
164
164
 
165
- def is_terminal(self):
165
+ def is_terminal(self) -> bool:
166
166
  return self not in self.nonterminal_statuses()
167
167
 
168
- def __lt__(self, other):
168
+ @classmethod
169
+ def user_code_failure_states(cls) -> Sequence['JobStatus']:
170
+ return (cls.FAILED, cls.FAILED_SETUP)
171
+
172
+ def __lt__(self, other: 'JobStatus') -> bool:
169
173
  return list(JobStatus).index(self) < list(JobStatus).index(other)
170
174
 
171
- def colored_str(self):
175
+ def colored_str(self) -> str:
172
176
  color = _JOB_STATUS_TO_COLOR[self]
173
177
  return f'{color}{self.value}{colorama.Style.RESET_ALL}'
174
178
 
@@ -205,31 +209,7 @@ class JobScheduler:
205
209
  _CURSOR.execute((f'UPDATE pending_jobs SET submit={int(time.time())} '
206
210
  f'WHERE job_id={job_id!r}'))
207
211
  _CONN.commit()
208
- # Use nohup to ensure the job driver process is a separate process tree,
209
- # instead of being a child of the current process. This is important to
210
- # avoid a chain of driver processes (job driver can call schedule_step()
211
- # to submit new jobs, and the new job can also call schedule_step()
212
- # recursively).
213
- #
214
- # echo $! will output the PID of the last background process started
215
- # in the current shell, so we can retrieve it and record in the DB.
216
- #
217
- # TODO(zhwu): A more elegant solution is to use another daemon process
218
- # to be in charge of starting these driver processes, instead of
219
- # starting them in the current process.
220
- wrapped_cmd = (f'nohup bash -c {shlex.quote(run_cmd)} '
221
- '</dev/null >/dev/null 2>&1 & echo $!')
222
- proc = subprocess.run(wrapped_cmd,
223
- stdout=subprocess.PIPE,
224
- stderr=subprocess.PIPE,
225
- stdin=subprocess.DEVNULL,
226
- start_new_session=True,
227
- check=True,
228
- shell=True,
229
- text=True)
230
- # Get the PID of the detached process
231
- pid = int(proc.stdout.strip())
232
-
212
+ pid = subprocess_utils.launch_new_process_tree(run_cmd)
233
213
  # TODO(zhwu): Backward compatibility, remove this check after 0.10.0.
234
214
  # This is for the case where the job is submitted with SkyPilot older
235
215
  # than #4318, using ray job submit.
sky/skylet/log_lib.py CHANGED
@@ -25,9 +25,9 @@ from sky.utils import log_utils
25
25
  from sky.utils import subprocess_utils
26
26
  from sky.utils import ux_utils
27
27
 
28
- _SKY_LOG_WAITING_GAP_SECONDS = 1
29
- _SKY_LOG_WAITING_MAX_RETRY = 5
30
- _SKY_LOG_TAILING_GAP_SECONDS = 0.2
28
+ SKY_LOG_WAITING_GAP_SECONDS = 1
29
+ SKY_LOG_WAITING_MAX_RETRY = 5
30
+ SKY_LOG_TAILING_GAP_SECONDS = 0.2
31
31
  # Peek the head of the lines to check if we need to start
32
32
  # streaming when tail > 0.
33
33
  PEEK_HEAD_LINES_FOR_START_STREAM = 20
@@ -336,7 +336,7 @@ def _follow_job_logs(file,
336
336
  ]:
337
337
  if wait_last_logs:
338
338
  # Wait until all the logs are printed before exiting.
339
- time.sleep(1 + _SKY_LOG_TAILING_GAP_SECONDS)
339
+ time.sleep(1 + SKY_LOG_TAILING_GAP_SECONDS)
340
340
  wait_last_logs = False
341
341
  continue
342
342
  status_str = status.value if status is not None else 'None'
@@ -345,7 +345,7 @@ def _follow_job_logs(file,
345
345
  f'Job finished (status: {status_str}).'))
346
346
  return
347
347
 
348
- time.sleep(_SKY_LOG_TAILING_GAP_SECONDS)
348
+ time.sleep(SKY_LOG_TAILING_GAP_SECONDS)
349
349
  status = job_lib.get_status_no_lock(job_id)
350
350
 
351
351
 
@@ -426,15 +426,15 @@ def tail_logs(job_id: Optional[int],
426
426
  retry_cnt += 1
427
427
  if os.path.exists(log_path) and status != job_lib.JobStatus.INIT:
428
428
  break
429
- if retry_cnt >= _SKY_LOG_WAITING_MAX_RETRY:
429
+ if retry_cnt >= SKY_LOG_WAITING_MAX_RETRY:
430
430
  print(
431
431
  f'{colorama.Fore.RED}ERROR: Logs for '
432
432
  f'{job_str} (status: {status.value}) does not exist '
433
433
  f'after retrying {retry_cnt} times.{colorama.Style.RESET_ALL}')
434
434
  return
435
- print(f'INFO: Waiting {_SKY_LOG_WAITING_GAP_SECONDS}s for the logs '
435
+ print(f'INFO: Waiting {SKY_LOG_WAITING_GAP_SECONDS}s for the logs '
436
436
  'to be written...')
437
- time.sleep(_SKY_LOG_WAITING_GAP_SECONDS)
437
+ time.sleep(SKY_LOG_WAITING_GAP_SECONDS)
438
438
  status = job_lib.update_job_status([job_id], silent=True)[0]
439
439
 
440
440
  start_stream_at = LOG_FILE_START_STREAMING_AT