skypilot-nightly 1.0.0.dev20250730__py3-none-any.whl → 1.0.0.dev20250731__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. (The original registry page linked to more details here.)

Files changed (72):
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +4 -1
  3. sky/backends/cloud_vm_ray_backend.py +4 -3
  4. sky/catalog/__init__.py +3 -3
  5. sky/catalog/aws_catalog.py +12 -0
  6. sky/catalog/common.py +2 -2
  7. sky/catalog/data_fetchers/fetch_aws.py +13 -1
  8. sky/client/cli/command.py +448 -53
  9. sky/dashboard/out/404.html +1 -1
  10. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  11. sky/dashboard/out/clusters/[cluster].html +1 -1
  12. sky/dashboard/out/clusters.html +1 -1
  13. sky/dashboard/out/config.html +1 -1
  14. sky/dashboard/out/index.html +1 -1
  15. sky/dashboard/out/infra/[context].html +1 -1
  16. sky/dashboard/out/infra.html +1 -1
  17. sky/dashboard/out/jobs/[job].html +1 -1
  18. sky/dashboard/out/jobs.html +1 -1
  19. sky/dashboard/out/users.html +1 -1
  20. sky/dashboard/out/volumes.html +1 -1
  21. sky/dashboard/out/workspace/new.html +1 -1
  22. sky/dashboard/out/workspaces/[name].html +1 -1
  23. sky/dashboard/out/workspaces.html +1 -1
  24. sky/jobs/__init__.py +3 -0
  25. sky/jobs/client/sdk.py +80 -3
  26. sky/jobs/controller.py +76 -25
  27. sky/jobs/recovery_strategy.py +80 -34
  28. sky/jobs/scheduler.py +68 -20
  29. sky/jobs/server/core.py +228 -136
  30. sky/jobs/server/server.py +40 -0
  31. sky/jobs/state.py +129 -24
  32. sky/jobs/utils.py +109 -51
  33. sky/provision/nebius/constants.py +3 -0
  34. sky/py.typed +0 -0
  35. sky/resources.py +16 -12
  36. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  37. sky/serve/autoscalers.py +8 -0
  38. sky/serve/client/impl.py +188 -0
  39. sky/serve/client/sdk.py +12 -82
  40. sky/serve/constants.py +5 -1
  41. sky/serve/controller.py +5 -0
  42. sky/serve/replica_managers.py +112 -37
  43. sky/serve/serve_state.py +16 -6
  44. sky/serve/serve_utils.py +274 -77
  45. sky/serve/server/core.py +8 -525
  46. sky/serve/server/impl.py +709 -0
  47. sky/serve/service.py +13 -9
  48. sky/serve/service_spec.py +74 -4
  49. sky/server/constants.py +1 -1
  50. sky/server/requests/payloads.py +33 -0
  51. sky/server/requests/requests.py +18 -1
  52. sky/server/requests/serializers/decoders.py +12 -3
  53. sky/server/requests/serializers/encoders.py +13 -2
  54. sky/skylet/events.py +9 -0
  55. sky/skypilot_config.py +24 -21
  56. sky/task.py +41 -11
  57. sky/templates/jobs-controller.yaml.j2 +3 -0
  58. sky/templates/sky-serve-controller.yaml.j2 +18 -2
  59. sky/users/server.py +1 -1
  60. sky/utils/command_runner.py +4 -2
  61. sky/utils/controller_utils.py +14 -10
  62. sky/utils/dag_utils.py +4 -2
  63. sky/utils/db/migration_utils.py +2 -4
  64. sky/utils/schemas.py +24 -19
  65. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250731.dist-info}/METADATA +1 -1
  66. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250731.dist-info}/RECORD +72 -68
  67. /sky/dashboard/out/_next/static/{_r2LwCFLjlWjZDUIJQG_V → oKqDxFQ88cquF4nQGE_0w}/_buildManifest.js +0 -0
  68. /sky/dashboard/out/_next/static/{_r2LwCFLjlWjZDUIJQG_V → oKqDxFQ88cquF4nQGE_0w}/_ssgManifest.js +0 -0
  69. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250731.dist-info}/WHEEL +0 -0
  70. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250731.dist-info}/entry_points.txt +0 -0
  71. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250731.dist-info}/licenses/LICENSE +0 -0
  72. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250731.dist-info}/top_level.txt +0 -0
sky/serve/service.py CHANGED
@@ -222,7 +222,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
222
222
  requested_resources_str=backend_utils.get_task_resources_str(task),
223
223
  load_balancing_policy=service_spec.load_balancing_policy,
224
224
  status=serve_state.ServiceStatus.CONTROLLER_INIT,
225
- tls_encrypted=service_spec.tls_credential is not None)
225
+ tls_encrypted=service_spec.tls_credential is not None,
226
+ pool=service_spec.pool)
226
227
  # Directly throw an error here. See sky/serve/api.py::up
227
228
  # for more details.
228
229
  if not success:
@@ -292,14 +293,17 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
292
293
  # TODO(tian): Probably we could enable multiple ports specified in
293
294
  # service spec and we could start multiple load balancers.
294
295
  # After that, we will have a mapping from replica port to endpoint.
295
- load_balancer_process = multiprocessing.Process(
296
- target=ux_utils.RedirectOutputForProcess(
297
- load_balancer.run_load_balancer,
298
- load_balancer_log_file).run,
299
- args=(controller_addr, load_balancer_port,
300
- service_spec.load_balancing_policy,
301
- service_spec.tls_credential))
302
- load_balancer_process.start()
296
+ # NOTE(tian): We don't need the load balancer for cluster pool.
297
+ # Skip the load balancer process for cluster pool.
298
+ if not service_spec.pool:
299
+ load_balancer_process = multiprocessing.Process(
300
+ target=ux_utils.RedirectOutputForProcess(
301
+ load_balancer.run_load_balancer,
302
+ load_balancer_log_file).run,
303
+ args=(controller_addr, load_balancer_port,
304
+ service_spec.load_balancing_policy,
305
+ service_spec.tls_credential))
306
+ load_balancer_process.start()
303
307
 
304
308
  if not is_recovery:
305
309
  serve_state.set_service_load_balancer_port(
sky/serve/service_spec.py CHANGED
@@ -43,7 +43,33 @@ class SkyServiceSpec:
43
43
  upscale_delay_seconds: Optional[int] = None,
44
44
  downscale_delay_seconds: Optional[int] = None,
45
45
  load_balancing_policy: Optional[str] = None,
46
+ pool: Optional[bool] = None,
46
47
  ) -> None:
48
+ if pool:
49
+ for unsupported_field in [
50
+ 'max_replicas',
51
+ 'num_overprovision',
52
+ 'target_qps_per_replica',
53
+ 'upscale_delay_seconds',
54
+ 'downscale_delay_seconds',
55
+ 'base_ondemand_fallback_replicas',
56
+ 'dynamic_ondemand_fallback',
57
+ 'spot_placer',
58
+ 'load_balancing_policy',
59
+ 'ports',
60
+ 'post_data',
61
+ 'tls_credential',
62
+ 'readiness_headers',
63
+ ]:
64
+ if locals()[unsupported_field] is not None:
65
+ with ux_utils.print_exception_no_traceback():
66
+ raise ValueError(
67
+ f'{unsupported_field} is not supported for pool.')
68
+ if max_replicas is not None and max_replicas != min_replicas:
69
+ with ux_utils.print_exception_no_traceback():
70
+ raise ValueError('Autoscaling is not supported for pool '
71
+ 'for now.')
72
+
47
73
  if max_replicas is not None and max_replicas < min_replicas:
48
74
  with ux_utils.print_exception_no_traceback():
49
75
  raise ValueError('max_replicas must be greater than or '
@@ -96,6 +122,7 @@ class SkyServiceSpec:
96
122
  self._upscale_delay_seconds: Optional[int] = upscale_delay_seconds
97
123
  self._downscale_delay_seconds: Optional[int] = downscale_delay_seconds
98
124
  self._load_balancing_policy: Optional[str] = load_balancing_policy
125
+ self._pool: Optional[bool] = pool
99
126
 
100
127
  self._use_ondemand_fallback: bool = (
101
128
  self.dynamic_ondemand_fallback is not None and
@@ -115,7 +142,7 @@ class SkyServiceSpec:
115
142
 
116
143
  service_config: Dict[str, Any] = {}
117
144
 
118
- readiness_section = config['readiness_probe']
145
+ readiness_section = config.get('readiness_probe', '/')
119
146
  if isinstance(readiness_section, str):
120
147
  service_config['readiness_path'] = readiness_section
121
148
  initial_delay_seconds = None
@@ -157,8 +184,29 @@ class SkyServiceSpec:
157
184
  raise ValueError('Port must be between 1 and 65535.')
158
185
  service_config['ports'] = str(ports) if ports is not None else None
159
186
 
187
+ pool_config = config.get('pool', None)
188
+ if pool_config is not None:
189
+ service_config['pool'] = pool_config
190
+
160
191
  policy_section = config.get('replica_policy', None)
192
+ if policy_section is not None and pool_config:
193
+ with ux_utils.print_exception_no_traceback():
194
+ raise ValueError('Cannot specify `replica_policy` for cluster '
195
+ 'pool. Only `workers: <num>` is supported '
196
+ 'for cluster pool now.')
197
+
161
198
  simplified_policy_section = config.get('replicas', None)
199
+ workers_config = config.get('workers', None)
200
+ if simplified_policy_section is not None and workers_config is not None:
201
+ with ux_utils.print_exception_no_traceback():
202
+ raise ValueError('Cannot specify both `replicas` and `workers`.'
203
+ ' Please use one of them.')
204
+ if simplified_policy_section is not None and pool_config:
205
+ with ux_utils.print_exception_no_traceback():
206
+ raise ValueError('Cannot specify `replicas` for cluster pool. '
207
+ 'Please use `workers` instead.')
208
+ if simplified_policy_section is None:
209
+ simplified_policy_section = workers_config
162
210
  if policy_section is None or simplified_policy_section is not None:
163
211
  if simplified_policy_section is not None:
164
212
  min_replicas = simplified_policy_section
@@ -239,6 +287,13 @@ class SkyServiceSpec:
239
287
  config[section] = dict()
240
288
  config[section][key] = value
241
289
 
290
+ add_if_not_none('pool', None, self._pool)
291
+
292
+ if self.pool:
293
+ # For pool, currently only `workers: <num>` is supported.
294
+ add_if_not_none('workers', None, self.min_replicas)
295
+ return config
296
+
242
297
  add_if_not_none('readiness_probe', 'path', self.readiness_path)
243
298
  add_if_not_none('readiness_probe', 'initial_delay_seconds',
244
299
  self.initial_delay_seconds)
@@ -306,10 +361,14 @@ class SkyServiceSpec:
306
361
  return ' '.join(policy_strs)
307
362
 
308
363
  def autoscaling_policy_str(self):
364
+ if self.pool:
365
+ # We only support fixed-size pool for now.
366
+ return f'Fixed-size ({self.min_replicas} workers)'
309
367
  # TODO(MaoZiming): Update policy_str
368
+ noun = 'worker' if self.pool else 'replica'
310
369
  min_plural = '' if self.min_replicas == 1 else 's'
311
370
  if self.max_replicas == self.min_replicas or self.max_replicas is None:
312
- return f'Fixed {self.min_replicas} replica{min_plural}'
371
+ return f'Fixed {self.min_replicas} {noun}{min_plural}'
313
372
  # Already checked in __init__.
314
373
  assert self.target_qps_per_replica is not None
315
374
  # TODO(tian): Refactor to contain more information
@@ -319,8 +378,8 @@ class SkyServiceSpec:
319
378
  overprovision_str = (
320
379
  f' with {self.num_overprovision} overprovisioned replicas')
321
380
  return (f'Autoscaling from {self.min_replicas} to {self.max_replicas} '
322
- f'replica{max_plural}{overprovision_str} (target QPS per '
323
- f'replica: {self.target_qps_per_replica})')
381
+ f'{noun}{max_plural}{overprovision_str} (target QPS per '
382
+ f'{noun}: {self.target_qps_per_replica})')
324
383
 
325
384
  def set_ports(self, ports: str) -> None:
326
385
  self._ports = ports
@@ -332,6 +391,10 @@ class SkyServiceSpec:
332
391
  f'Certfile: {self.tls_credential.certfile}')
333
392
 
334
393
  def __repr__(self) -> str:
394
+ if self.pool:
395
+ return textwrap.dedent(f"""\
396
+ Worker policy: {self.autoscaling_policy_str()}
397
+ """)
335
398
  return textwrap.dedent(f"""\
336
399
  Readiness probe method: {self.probe_str()}
337
400
  Readiness initial delay seconds: {self.initial_delay_seconds}
@@ -420,3 +483,10 @@ class SkyServiceSpec:
420
483
  def load_balancing_policy(self) -> str:
421
484
  return lb_policies.LoadBalancingPolicy.make_policy_name(
422
485
  self._load_balancing_policy)
486
+
487
+ @property
488
+ def pool(self) -> bool:
489
+ # This can happen for backward compatibility.
490
+ if not hasattr(self, '_pool'):
491
+ return False
492
+ return bool(self._pool)
sky/server/constants.py CHANGED
@@ -10,7 +10,7 @@ from sky.skylet import constants
10
10
  # based on version info is needed.
11
11
  # For more details and code guidelines, refer to:
12
12
  # https://docs.skypilot.co/en/latest/developers/CONTRIBUTING.html#backward-compatibility-guidelines
13
- API_VERSION = 11
13
+ API_VERSION = 12
14
14
 
15
15
  # The minimum peer API version that the code should still work with.
16
16
  # Notes (dev):
sky/server/requests/payloads.py CHANGED
@@ -478,6 +478,8 @@ class JobsLaunchBody(RequestBody):
478
478
  """The request body for the jobs launch endpoint."""
479
479
  task: str
480
480
  name: Optional[str]
481
+ pool: Optional[str] = None
482
+ num_jobs: Optional[int] = None
481
483
 
482
484
  def to_kwargs(self) -> Dict[str, Any]:
483
485
  kwargs = super().to_kwargs()
@@ -500,6 +502,7 @@ class JobsCancelBody(RequestBody):
500
502
  job_ids: Optional[List[int]] = None
501
503
  all: bool = False
502
504
  all_users: bool = False
505
+ pool: Optional[str] = None
503
506
 
504
507
 
505
508
  class JobsLogsBody(RequestBody):
@@ -671,6 +674,36 @@ class JobsDownloadLogsBody(RequestBody):
671
674
  local_dir: str = constants.SKY_LOGS_DIRECTORY
672
675
 
673
676
 
677
+ class JobsPoolApplyBody(RequestBody):
678
+ """The request body for the jobs pool apply endpoint."""
679
+ task: str
680
+ pool_name: str
681
+ mode: serve.UpdateMode
682
+
683
+ def to_kwargs(self) -> Dict[str, Any]:
684
+ kwargs = super().to_kwargs()
685
+ dag = common.process_mounts_in_task_on_api_server(self.task,
686
+ self.env_vars,
687
+ workdir_only=False)
688
+ assert len(
689
+ dag.tasks) == 1, ('Must only specify one task in the DAG for '
690
+ 'a pool.', dag)
691
+ kwargs['task'] = dag.tasks[0]
692
+ return kwargs
693
+
694
+
695
+ class JobsPoolDownBody(RequestBody):
696
+ """The request body for the jobs pool down endpoint."""
697
+ pool_names: Optional[Union[str, List[str]]]
698
+ all: bool = False
699
+ purge: bool = False
700
+
701
+
702
+ class JobsPoolStatusBody(RequestBody):
703
+ """The request body for the jobs pool status endpoint."""
704
+ pool_names: Optional[Union[str, List[str]]]
705
+
706
+
674
707
  class UploadZipFileResponse(pydantic.BaseModel):
675
708
  """The response body for the upload zip file endpoint."""
676
709
  status: str
sky/server/requests/requests.py CHANGED
@@ -361,7 +361,24 @@ def managed_job_status_refresh_event():
361
361
  managed_job_utils.ha_recovery_for_consolidation_mode()
362
362
  # After recovery, we start the event loop.
363
363
  from sky.skylet import events
364
- event = events.ManagedJobEvent()
364
+ refresh_event = events.ManagedJobEvent()
365
+ scheduling_event = events.ManagedJobSchedulingEvent()
366
+ while True:
367
+ logger.info('=== Running managed job event ===')
368
+ refresh_event.run()
369
+ scheduling_event.run()
370
+ time.sleep(events.EVENT_CHECKING_INTERVAL_SECONDS)
371
+
372
+
373
+ def sky_serve_status_refresh_event():
374
+ """Refresh the managed job status for controller consolidation mode."""
375
+ # pylint: disable=import-outside-toplevel
376
+ from sky.serve import serve_utils
377
+ if not serve_utils.is_consolidation_mode():
378
+ return
379
+ # TODO(tian): Add HA recovery logic.
380
+ from sky.skylet import events
381
+ event = events.ServiceUpdateEvent()
365
382
  while True:
366
383
  time.sleep(events.EVENT_CHECKING_INTERVAL_SECONDS)
367
384
  event.run()
sky/server/requests/serializers/decoders.py CHANGED
@@ -109,9 +109,8 @@ def decode_jobs_queue(return_value: List[dict],) -> List[Dict[str, Any]]:
109
109
  return jobs
110
110
 
111
111
 
112
- @register_decoders('serve.status')
113
- def decode_serve_status(return_value: List[dict]) -> List[Dict[str, Any]]:
114
- service_statuses = return_value
112
+ def _decode_serve_status(
113
+ service_statuses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
115
114
  for service_status in service_statuses:
116
115
  service_status['status'] = serve_state.ServiceStatus(
117
116
  service_status['status'])
@@ -122,6 +121,16 @@ def decode_serve_status(return_value: List[dict]) -> List[Dict[str, Any]]:
122
121
  return service_statuses
123
122
 
124
123
 
124
+ @register_decoders('serve.status')
125
+ def decode_serve_status(return_value: List[dict]) -> List[Dict[str, Any]]:
126
+ return _decode_serve_status(return_value)
127
+
128
+
129
+ @register_decoders('jobs.pool_status')
130
+ def decode_jobs_pool_status(return_value: List[dict]) -> List[Dict[str, Any]]:
131
+ return _decode_serve_status(return_value)
132
+
133
+
125
134
  @register_decoders('cost_report')
126
135
  def decode_cost_report(
127
136
  return_value: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
sky/server/requests/serializers/encoders.py CHANGED
@@ -112,8 +112,7 @@ def encode_jobs_queue(jobs: List[dict],) -> List[Dict[str, Any]]:
112
112
  return jobs
113
113
 
114
114
 
115
- @register_encoder('serve.status')
116
- def encode_serve_status(
115
+ def _encode_serve_status(
117
116
  service_statuses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
118
117
  for service_status in service_statuses:
119
118
  service_status['status'] = service_status['status'].value
@@ -123,6 +122,18 @@ def encode_serve_status(
123
122
  return service_statuses
124
123
 
125
124
 
125
+ @register_encoder('serve.status')
126
+ def encode_serve_status(
127
+ service_statuses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
128
+ return _encode_serve_status(service_statuses)
129
+
130
+
131
+ @register_encoder('jobs.pool_status')
132
+ def encode_jobs_pool_status(
133
+ pool_statuses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
134
+ return _encode_serve_status(pool_statuses)
135
+
136
+
126
137
  @register_encoder('cost_report')
127
138
  def encode_cost_report(
128
139
  cost_report: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
sky/skylet/events.py CHANGED
@@ -75,7 +75,16 @@ class ManagedJobEvent(SkyletEvent):
75
75
  EVENT_INTERVAL_SECONDS = 300
76
76
 
77
77
  def _run(self):
78
+ logger.info('=== Updating managed job status ===')
78
79
  managed_job_utils.update_managed_jobs_statuses()
80
+
81
+
82
+ class ManagedJobSchedulingEvent(SkyletEvent):
83
+ """Skylet event for scheduling managed jobs."""
84
+ EVENT_INTERVAL_SECONDS = 20
85
+
86
+ def _run(self):
87
+ logger.info('=== Scheduling next jobs ===')
79
88
  managed_job_scheduler.maybe_schedule_next_jobs()
80
89
 
81
90
 
sky/skypilot_config.py CHANGED
@@ -495,6 +495,12 @@ def parse_and_validate_config_file(config_path: str) -> config_utils.Config:
495
495
  try:
496
496
  config_dict = common_utils.read_yaml(config_path)
497
497
  config = config_utils.Config.from_dict(config_dict)
498
+ # pop the db url from the config, and set it to the env var.
499
+ # this is to avoid db url (considered a sensitive value)
500
+ # being printed with the rest of the config.
501
+ db_url = config.pop_nested(('db',), None)
502
+ if db_url:
503
+ os.environ[constants.ENV_VAR_DB_CONNECTION_URI] = db_url
498
504
  if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
499
505
  logger.debug(f'Config loaded from {config_path}:\n'
500
506
  f'{common_utils.dump_yaml_str(dict(config))}')
@@ -556,21 +562,16 @@ def _reload_config_as_server() -> None:
556
562
  _set_loaded_config_path(None)
557
563
 
558
564
  server_config_path = _resolve_server_config_path()
559
- db_url_from_env = os.environ.get(constants.ENV_VAR_DB_CONNECTION_URI)
560
565
  server_config = _get_config_from_path(server_config_path)
561
- if db_url_from_env:
562
- server_config.set_nested(('db',), db_url_from_env)
563
-
564
- if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
565
- logger.debug(f'server config: \n'
566
- f'{common_utils.dump_yaml_str(dict(server_config))}')
567
-
568
- db_url = server_config.get_nested(('db',), None)
569
- if db_url and len(server_config.keys()) > 1:
570
- raise ValueError(
571
- 'if db config is specified, no other config is allowed')
566
+ # Get the db url from the env var. _get_config_from_path should have moved
567
+ # the db url specified in config file to the env var.
568
+ db_url = os.environ.get(constants.ENV_VAR_DB_CONNECTION_URI)
572
569
 
573
570
  if db_url:
571
+ if len(server_config.keys()) > 1:
572
+ raise ValueError(
573
+ 'If db config is specified, no other config is allowed')
574
+ logger.debug('retrieving config from database')
574
575
  with _DB_USE_LOCK:
575
576
  sqlalchemy_engine = sqlalchemy.create_engine(db_url,
576
577
  poolclass=NullPool)
@@ -591,14 +592,13 @@ def _reload_config_as_server() -> None:
591
592
 
592
593
  db_config = _get_config_yaml_from_db(API_SERVER_CONFIG_KEY)
593
594
  if db_config:
594
- if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
595
- logger.debug(
596
- f'Config loaded from db:\n'
597
- f'{common_utils.dump_yaml_str(dict(db_config))}')
598
595
  server_config = overlay_skypilot_config(server_config,
599
596
  db_config)
600
597
  # Close the engine to avoid connection leaks
601
598
  sqlalchemy_engine.dispose()
599
+ if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
600
+ logger.debug(f'server config: \n'
601
+ f'{common_utils.dump_yaml_str(dict(server_config))}')
602
602
  _set_loaded_config(server_config)
603
603
  _set_loaded_config_path(server_config_path)
604
604
 
@@ -681,6 +681,10 @@ def override_skypilot_config(
681
681
 
682
682
  disallowed_diff_keys = []
683
683
  for key in constants.SKIPPED_CLIENT_OVERRIDE_KEYS:
684
+ if key == ('db',):
685
+ # since db key is popped out of server config, the key is expected
686
+ # to be different between client and server.
687
+ continue
684
688
  value = override_configs.pop_nested(key, default_value=None)
685
689
  if (value is not None and
686
690
  value != original_config.get_nested(key, default_value=None)):
@@ -855,11 +859,11 @@ def update_api_server_config_no_lock(config: config_utils.Config) -> None:
855
859
 
856
860
  db_updated = False
857
861
  if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
858
- existing_db_url = get_nested(('db',), None)
862
+ existing_db_url = os.environ.get(constants.ENV_VAR_DB_CONNECTION_URI)
863
+ new_db_url = config.pop_nested(('db',), None)
864
+ if new_db_url and new_db_url != existing_db_url:
865
+ raise ValueError('Cannot change db url while server is running')
859
866
  if existing_db_url:
860
- new_db_url = config.get_nested(('db',), None)
861
- if new_db_url and new_db_url != existing_db_url:
862
- raise ValueError('Cannot change db url while server is running')
863
867
  with _DB_USE_LOCK:
864
868
  sqlalchemy_engine = sqlalchemy.create_engine(existing_db_url,
865
869
  poolclass=NullPool)
@@ -869,7 +873,6 @@ def update_api_server_config_no_lock(config: config_utils.Config) -> None:
869
873
  def _set_config_yaml_to_db(key: str,
870
874
  config: config_utils.Config):
871
875
  assert sqlalchemy_engine is not None
872
- config.pop_nested(('db',), None)
873
876
  config_str = common_utils.dump_yaml_str(dict(config))
874
877
  with orm.Session(sqlalchemy_engine) as session:
875
878
  if (sqlalchemy_engine.dialect.name ==
sky/task.py CHANGED
@@ -256,6 +256,7 @@ class Task:
256
256
  file_mounts_mapping: Optional[Dict[str, str]] = None,
257
257
  volume_mounts: Optional[List[volume_lib.VolumeMount]] = None,
258
258
  metadata: Optional[Dict[str, Any]] = None,
259
+ _user_specified_yaml: Optional[str] = None,
259
260
  ):
260
261
  """Initializes a Task.
261
262
 
@@ -381,6 +382,8 @@ class Task:
381
382
  if dag is not None:
382
383
  dag.add(self)
383
384
 
385
+ self._user_specified_yaml = _user_specified_yaml
386
+
384
387
  def validate(self,
385
388
  skip_file_mounts: bool = False,
386
389
  skip_workdir: bool = False):
@@ -525,6 +528,8 @@ class Task:
525
528
  env_overrides: Optional[List[Tuple[str, str]]] = None,
526
529
  secrets_overrides: Optional[List[Tuple[str, str]]] = None,
527
530
  ) -> 'Task':
531
+ user_specified_yaml = config.pop('_user_specified_yaml',
532
+ common_utils.dump_yaml_str(config))
528
533
  # More robust handling for 'envs': explicitly convert keys and values to
529
534
  # str, since users may pass '123' as keys/values which will get parsed
530
535
  # as int causing validate_schema() to fail.
@@ -590,19 +595,23 @@ class Task:
590
595
 
591
596
  # Fill in any Task.envs into file_mounts (src/dst paths, storage
592
597
  # name/source).
598
+ env_vars = config.get('envs', {})
599
+ secrets = config.get('secrets', {})
600
+ env_and_secrets = env_vars.copy()
601
+ env_and_secrets.update(secrets)
593
602
  if config.get('file_mounts') is not None:
594
603
  config['file_mounts'] = _fill_in_env_vars(config['file_mounts'],
595
- config.get('envs', {}))
604
+ env_and_secrets)
596
605
 
597
606
  # Fill in any Task.envs into service (e.g. MODEL_NAME).
598
607
  if config.get('service') is not None:
599
608
  config['service'] = _fill_in_env_vars(config['service'],
600
- config.get('envs', {}))
609
+ env_and_secrets)
601
610
 
602
611
  # Fill in any Task.envs into workdir
603
612
  if config.get('workdir') is not None:
604
613
  config['workdir'] = _fill_in_env_vars(config['workdir'],
605
- config.get('envs', {}))
614
+ env_and_secrets)
606
615
 
607
616
  task = Task(
608
617
  config.pop('name', None),
@@ -616,6 +625,7 @@ class Task:
616
625
  file_mounts_mapping=config.pop('file_mounts_mapping', None),
617
626
  volumes=config.pop('volumes', None),
618
627
  metadata=config.pop('_metadata', None),
628
+ _user_specified_yaml=user_specified_yaml,
619
629
  )
620
630
 
621
631
  # Create lists to store storage objects inlined in file_mounts.
@@ -736,9 +746,19 @@ class Task:
736
746
  task.set_resources(sky.Resources.from_yaml_config(resources_config))
737
747
 
738
748
  service = config.pop('service', None)
749
+ pool = config.pop('pool', None)
750
+ if service is not None and pool is not None:
751
+ with ux_utils.print_exception_no_traceback():
752
+ raise ValueError(
753
+ 'Cannot set both service and pool in the same task.')
754
+
739
755
  if service is not None:
740
756
  service = service_spec.SkyServiceSpec.from_yaml_config(service)
741
- task.set_service(service)
757
+ task.set_service(service)
758
+ elif pool is not None:
759
+ pool['pool'] = True
760
+ pool = service_spec.SkyServiceSpec.from_yaml_config(pool)
761
+ task.set_service(pool)
742
762
 
743
763
  volume_mounts = config.pop('volume_mounts', None)
744
764
  if volume_mounts is not None:
@@ -773,7 +793,8 @@ class Task:
773
793
  # TODO(zongheng): use
774
794
  # https://github.com/yaml/pyyaml/issues/165#issuecomment-430074049
775
795
  # to raise errors on duplicate keys.
776
- config = yaml.safe_load(f)
796
+ user_specified_yaml = f.read()
797
+ config = yaml.safe_load(user_specified_yaml)
777
798
 
778
799
  if isinstance(config, str):
779
800
  with ux_utils.print_exception_no_traceback():
@@ -782,6 +803,7 @@ class Task:
782
803
 
783
804
  if config is None:
784
805
  config = {}
806
+ config['_user_specified_yaml'] = user_specified_yaml
785
807
  return Task.from_yaml_config(config)
786
808
 
787
809
  def resolve_and_validate_volumes(self) -> None:
@@ -1537,11 +1559,22 @@ class Task:
1537
1559
  d[k] = v
1538
1560
  return d
1539
1561
 
1540
- def to_yaml_config(self, redact_secrets: bool = False) -> Dict[str, Any]:
1562
+ def to_yaml_config(self,
1563
+ use_user_specified_yaml: bool = False) -> Dict[str, Any]:
1541
1564
  """Returns a yaml-style dict representation of the task.
1542
1565
 
1543
1566
  INTERNAL: this method is internal-facing.
1544
1567
  """
1568
+ if use_user_specified_yaml:
1569
+ if self._user_specified_yaml is None:
1570
+ return self._to_yaml_config(redact_secrets=True)
1571
+ config = yaml.safe_load(self._user_specified_yaml)
1572
+ if config.get('secrets') is not None:
1573
+ config['secrets'] = {k: '<redacted>' for k in config['secrets']}
1574
+ return config
1575
+ return self._to_yaml_config()
1576
+
1577
+ def _to_yaml_config(self, redact_secrets: bool = False) -> Dict[str, Any]:
1545
1578
  config = {}
1546
1579
 
1547
1580
  def add_if_not_none(key, value, no_empty: bool = False):
@@ -1586,13 +1619,9 @@ class Task:
1586
1619
  # Add envs without redaction
1587
1620
  add_if_not_none('envs', self.envs, no_empty=True)
1588
1621
 
1589
- # Add secrets with redaction if requested
1590
1622
  secrets = self.secrets
1591
1623
  if secrets and redact_secrets:
1592
- secrets = {
1593
- k: '<redacted>' if isinstance(v, str) else v
1594
- for k, v in secrets.items()
1595
- }
1624
+ secrets = {k: '<redacted>' for k in secrets}
1596
1625
  add_if_not_none('secrets', secrets, no_empty=True)
1597
1626
 
1598
1627
  add_if_not_none('file_mounts', {})
@@ -1615,6 +1644,7 @@ class Task:
1615
1644
  ]
1616
1645
  # we manually check if its empty to not clog up the generated yaml
1617
1646
  add_if_not_none('_metadata', self._metadata if self._metadata else None)
1647
+ add_if_not_none('_user_specified_yaml', self._user_specified_yaml)
1618
1648
  return config
1619
1649
 
1620
1650
  def get_required_cloud_features(
sky/templates/jobs-controller.yaml.j2 CHANGED
@@ -57,6 +57,9 @@ run: |
57
57
  --job-id $SKYPILOT_INTERNAL_JOB_ID \
58
58
  {%- endif %}
59
59
  --env-file {{remote_env_file_path}} \
60
+ {%- if pool is not none %}
61
+ --pool {{pool}} \
62
+ {%- endif %}
60
63
  --priority {{priority}}
61
64
 
62
65
 
sky/templates/sky-serve-controller.yaml.j2 CHANGED
@@ -45,13 +45,29 @@ file_mounts:
45
45
  run: |
46
46
  # Activate the Python environment, so that cloud SDKs can be found in the
47
47
  # PATH.
48
+ {%- if consolidation_mode_job_id is none %}
48
49
  {{ sky_activate_python_env }}
50
+ {%- endif %}
49
51
  # Start sky serve service.
50
- python -u -m sky.serve.service \
52
+ {%- if consolidation_mode_job_id is not none %}
53
+ {{sky_python_cmd}} \
54
+ {%- else %}
55
+ python \
56
+ {%- endif %}
57
+ -u -m sky.serve.service \
51
58
  --service-name {{service_name}} \
52
59
  --task-yaml {{remote_task_yaml_path}} \
60
+ {%- if consolidation_mode_job_id is not none %}
61
+ --job-id {{consolidation_mode_job_id}} \
62
+ {%- else %}
53
63
  --job-id $SKYPILOT_INTERNAL_JOB_ID \
54
- >> {{controller_log_file}} 2>&1
64
+ {%- endif %}
65
+ >> {{controller_log_file}} 2>&1 \
66
+ {%- if consolidation_mode_job_id is not none %}
67
+ &
68
+ {%- endif %}
69
+ # For consolidation mode, we need to run the service in the background so
70
+ # that it can immediately return in serve.core.up().
55
71
 
56
72
  envs:
57
73
  {%- for env_name, env_value in controller_envs.items() %}
sky/users/server.py CHANGED
@@ -414,7 +414,7 @@ async def get_service_account_tokens(
414
414
 
415
415
  def _generate_service_account_user_id() -> str:
416
416
  """Generate a unique user ID for a service account."""
417
- random_suffix = secrets.token_hex(16) # 16 character hex string
417
+ random_suffix = secrets.token_hex(8) # 16 character hex string
418
418
  service_account_id = (f'sa-{random_suffix}')
419
419
  return service_account_id
420
420
 
sky/utils/command_runner.py CHANGED
@@ -201,6 +201,7 @@ class CommandRunner:
201
201
  separate_stderr: bool,
202
202
  skip_num_lines: int,
203
203
  source_bashrc: bool = False,
204
+ use_login: bool = True,
204
205
  ) -> str:
205
206
  """Returns the command to run."""
206
207
  if isinstance(cmd, list):
@@ -211,7 +212,7 @@ class CommandRunner:
211
212
  '/bin/bash',
212
213
  '--login',
213
214
  '-c',
214
- ]
215
+ ] if use_login else ['/bin/bash', '-c']
215
216
  if source_bashrc:
216
217
  command += [
217
218
  # Need this `-i` option to make sure `source ~/.bashrc` work.
@@ -1124,7 +1125,8 @@ class LocalProcessCommandRunner(CommandRunner):
1124
1125
  process_stream,
1125
1126
  separate_stderr,
1126
1127
  skip_num_lines=skip_num_lines,
1127
- source_bashrc=source_bashrc)
1128
+ source_bashrc=source_bashrc,
1129
+ use_login=False)
1128
1130
 
1129
1131
  log_dir = os.path.expanduser(os.path.dirname(log_path))
1130
1132
  os.makedirs(log_dir, exist_ok=True)