skypilot-nightly 1.0.0.dev20250730__py3-none-any.whl → 1.0.0.dev20250801__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of skypilot-nightly might be problematic.

Files changed (81)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +4 -1
  3. sky/backends/cloud_vm_ray_backend.py +4 -3
  4. sky/catalog/__init__.py +3 -3
  5. sky/catalog/aws_catalog.py +12 -0
  6. sky/catalog/common.py +2 -2
  7. sky/catalog/data_fetchers/fetch_aws.py +13 -1
  8. sky/client/cli/command.py +452 -53
  9. sky/dashboard/out/404.html +1 -1
  10. sky/dashboard/out/_next/static/chunks/{webpack-5adfc4d4b3db6f71.js → webpack-42cd1b19a6b01078.js} +1 -1
  11. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  12. sky/dashboard/out/clusters/[cluster].html +1 -1
  13. sky/dashboard/out/clusters.html +1 -1
  14. sky/dashboard/out/config.html +1 -1
  15. sky/dashboard/out/index.html +1 -1
  16. sky/dashboard/out/infra/[context].html +1 -1
  17. sky/dashboard/out/infra.html +1 -1
  18. sky/dashboard/out/jobs/[job].html +1 -1
  19. sky/dashboard/out/jobs.html +1 -1
  20. sky/dashboard/out/users.html +1 -1
  21. sky/dashboard/out/volumes.html +1 -1
  22. sky/dashboard/out/workspace/new.html +1 -1
  23. sky/dashboard/out/workspaces/[name].html +1 -1
  24. sky/dashboard/out/workspaces.html +1 -1
  25. sky/data/data_utils.py +21 -1
  26. sky/data/storage.py +12 -0
  27. sky/jobs/__init__.py +3 -0
  28. sky/jobs/client/sdk.py +80 -3
  29. sky/jobs/controller.py +76 -25
  30. sky/jobs/recovery_strategy.py +80 -34
  31. sky/jobs/scheduler.py +68 -20
  32. sky/jobs/server/core.py +228 -136
  33. sky/jobs/server/server.py +40 -0
  34. sky/jobs/state.py +129 -24
  35. sky/jobs/utils.py +109 -51
  36. sky/provision/nebius/constants.py +3 -0
  37. sky/provision/runpod/utils.py +27 -12
  38. sky/py.typed +0 -0
  39. sky/resources.py +16 -12
  40. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  41. sky/serve/autoscalers.py +8 -0
  42. sky/serve/client/impl.py +188 -0
  43. sky/serve/client/sdk.py +12 -82
  44. sky/serve/constants.py +5 -1
  45. sky/serve/controller.py +5 -0
  46. sky/serve/replica_managers.py +112 -37
  47. sky/serve/serve_state.py +16 -6
  48. sky/serve/serve_utils.py +274 -77
  49. sky/serve/server/core.py +8 -525
  50. sky/serve/server/impl.py +709 -0
  51. sky/serve/service.py +13 -9
  52. sky/serve/service_spec.py +74 -4
  53. sky/server/constants.py +1 -1
  54. sky/server/daemons.py +164 -0
  55. sky/server/requests/payloads.py +33 -0
  56. sky/server/requests/requests.py +2 -107
  57. sky/server/requests/serializers/decoders.py +12 -3
  58. sky/server/requests/serializers/encoders.py +13 -2
  59. sky/server/server.py +2 -1
  60. sky/server/uvicorn.py +2 -1
  61. sky/sky_logging.py +30 -0
  62. sky/skylet/constants.py +2 -1
  63. sky/skylet/events.py +9 -0
  64. sky/skypilot_config.py +24 -21
  65. sky/task.py +41 -11
  66. sky/templates/jobs-controller.yaml.j2 +3 -0
  67. sky/templates/sky-serve-controller.yaml.j2 +18 -2
  68. sky/users/server.py +1 -1
  69. sky/utils/command_runner.py +4 -2
  70. sky/utils/controller_utils.py +14 -10
  71. sky/utils/dag_utils.py +4 -2
  72. sky/utils/db/migration_utils.py +2 -4
  73. sky/utils/schemas.py +47 -19
  74. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250801.dist-info}/METADATA +1 -1
  75. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250801.dist-info}/RECORD +81 -76
  76. /sky/dashboard/out/_next/static/{_r2LwCFLjlWjZDUIJQG_V → f2fEsZwJxryJVOYRNtNKE}/_buildManifest.js +0 -0
  77. /sky/dashboard/out/_next/static/{_r2LwCFLjlWjZDUIJQG_V → f2fEsZwJxryJVOYRNtNKE}/_ssgManifest.js +0 -0
  78. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250801.dist-info}/WHEEL +0 -0
  79. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250801.dist-info}/entry_points.txt +0 -0
  80. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250801.dist-info}/licenses/LICENSE +0 -0
  81. {skypilot_nightly-1.0.0.dev20250730.dist-info → skypilot_nightly-1.0.0.dev20250801.dist-info}/top_level.txt +0 -0
sky/jobs/state.py CHANGED
@@ -100,6 +100,13 @@ job_info_table = sqlalchemy.Table(
     sqlalchemy.Column('original_user_yaml_path',
                       sqlalchemy.Text,
                       server_default=None),
+    sqlalchemy.Column('pool', sqlalchemy.Text, server_default=None),
+    sqlalchemy.Column('current_cluster_name',
+                      sqlalchemy.Text,
+                      server_default=None),
+    sqlalchemy.Column('job_id_on_pool_cluster',
+                      sqlalchemy.Integer,
+                      server_default=None),
 )
 
 ha_recovery_script_table = sqlalchemy.Table(
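
The three new nullable columns come with a schema migration; the files-changed list above includes a new sky/schemas/db/spot_jobs/002_cluster_pool.py (+42) whose contents are not shown in this diff. A minimal sketch of what such an additive migration could look like, assuming a plain SQLAlchemy engine (the function name and table name are assumptions, only the column definitions come from the diff):

    import sqlalchemy

    def upgrade(engine: 'sqlalchemy.engine.Engine') -> None:
        """Sketch: add the pool-related columns introduced in this release."""
        with engine.begin() as conn:
            # Nullable TEXT/INTEGER columns can be added in place on both
            # SQLite and Postgres; no table rebuild is required.
            for ddl in (
                    'ALTER TABLE job_info ADD COLUMN pool TEXT DEFAULT NULL',
                    'ALTER TABLE job_info ADD COLUMN current_cluster_name '
                    'TEXT DEFAULT NULL',
                    'ALTER TABLE job_info ADD COLUMN job_id_on_pool_cluster '
                    'INTEGER DEFAULT NULL',
            ):
                conn.execute(sqlalchemy.text(ddl))
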
@@ -215,6 +222,9 @@ def _get_jobs_dict(r: 'row.RowMapping') -> Dict[str, Any]:
         'priority': r['priority'],
         'entrypoint': r['entrypoint'],
         'original_user_yaml_path': r['original_user_yaml_path'],
+        'pool': r['pool'],
+        'current_cluster_name': r['current_cluster_name'],
+        'job_id_on_pool_cluster': r['job_id_on_pool_cluster'],
     }
 
 
@@ -451,8 +461,8 @@ def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
 
 
 @_init_db
-def set_job_info_without_job_id(name: str, workspace: str,
-                                entrypoint: str) -> int:
+def set_job_info_without_job_id(name: str, workspace: str, entrypoint: str,
+                                pool: Optional[str]) -> int:
     assert _SQLALCHEMY_ENGINE is not None
     with orm.Session(_SQLALCHEMY_ENGINE) as session:
         if (_SQLALCHEMY_ENGINE.dialect.name ==
@@ -469,6 +479,7 @@ def set_job_info_without_job_id(name: str, workspace: str,
                 schedule_state=ManagedJobScheduleState.INACTIVE.value,
                 workspace=workspace,
                 entrypoint=entrypoint,
+                pool=pool,
             )
 
         if (_SQLALCHEMY_ENGINE.dialect.name ==
@@ -1278,6 +1289,56 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
     return updated_count == 0
 
 
+@_init_db
+def get_pool_from_job_id(job_id: int) -> Optional[str]:
+    """Get the pool from the job id."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        pool = session.execute(
+            sqlalchemy.select(job_info_table.c.pool).where(
+                job_info_table.c.spot_job_id == job_id)).fetchone()
+        return pool[0] if pool else None
+
+
+@_init_db
+def set_current_cluster_name(job_id: int, current_cluster_name: str) -> None:
+    """Set the current cluster name for a job."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(job_info_table).filter(
+            job_info_table.c.spot_job_id == job_id).update(
+                {job_info_table.c.current_cluster_name: current_cluster_name})
+        session.commit()
+
+
+@_init_db
+def set_job_id_on_pool_cluster(job_id: int,
+                               job_id_on_pool_cluster: int) -> None:
+    """Set the job id on the pool cluster for a job."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(job_info_table).filter(
+            job_info_table.c.spot_job_id == job_id).update({
+                job_info_table.c.job_id_on_pool_cluster: job_id_on_pool_cluster
+            })
+        session.commit()
+
+
+@_init_db
+def get_pool_submit_info(job_id: int) -> Tuple[Optional[str], Optional[int]]:
+    """Get the cluster name and job id on the pool from the managed job id."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        info = session.execute(
+            sqlalchemy.select(
+                job_info_table.c.current_cluster_name,
+                job_info_table.c.job_id_on_pool_cluster).where(
+                    job_info_table.c.spot_job_id == job_id)).fetchone()
+        if info is None:
+            return None, None
+        return info[0], info[1]
+
+
 @_init_db
 def scheduler_set_launching(job_id: int,
                             current_state: ManagedJobScheduleState) -> None:
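
A short usage sketch for the new helpers (hypothetical values, not part of the diff): once a managed job is handed to a pool worker, the controller can record where it actually runs and recover that information later.

    from sky.jobs import state as managed_job_state

    # Record where managed job 7 landed: the worker cluster and the job's
    # local id on that cluster.
    managed_job_state.set_current_cluster_name(7, 'pool-worker-0')
    managed_job_state.set_job_id_on_pool_cluster(7, 42)

    # Later lookups used by the status/log/cancel paths:
    managed_job_state.get_pool_from_job_id(7)  # -> the pool name, or None
    cluster, job_on_cluster = managed_job_state.get_pool_submit_info(7)
    # -> ('pool-worker-0', 42)
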
@@ -1398,28 +1459,68 @@ def get_num_launching_jobs() -> int:
             sqlalchemy.select(
                 sqlalchemy.func.count()  # pylint: disable=not-callable
             ).select_from(job_info_table).where(
-                job_info_table.c.schedule_state ==
-                ManagedJobScheduleState.LAUNCHING.value)).fetchone()[0]
+                sqlalchemy.and_(
+                    job_info_table.c.schedule_state ==
+                    ManagedJobScheduleState.LAUNCHING.value,
+                    # We only count jobs that are not in a pool, because
+                    # jobs in a pool do not actually call sky.launch.
+                    job_info_table.c.pool.is_(None)))).fetchone()[0]
 
 
 @_init_db
-def get_num_alive_jobs() -> int:
+def get_num_alive_jobs(pool: Optional[str] = None) -> int:
     assert _SQLALCHEMY_ENGINE is not None
     with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        where_conditions = [
+            job_info_table.c.schedule_state.in_([
+                ManagedJobScheduleState.ALIVE_WAITING.value,
+                ManagedJobScheduleState.LAUNCHING.value,
+                ManagedJobScheduleState.ALIVE.value,
+                ManagedJobScheduleState.ALIVE_BACKOFF.value,
+            ])
+        ]
+
+        if pool is not None:
+            where_conditions.append(job_info_table.c.pool == pool)
+
         return session.execute(
             sqlalchemy.select(
                 sqlalchemy.func.count()  # pylint: disable=not-callable
             ).select_from(job_info_table).where(
-                job_info_table.c.schedule_state.in_([
-                    ManagedJobScheduleState.ALIVE_WAITING.value,
-                    ManagedJobScheduleState.LAUNCHING.value,
-                    ManagedJobScheduleState.ALIVE.value,
-                    ManagedJobScheduleState.ALIVE_BACKOFF.value,
-                ]))).fetchone()[0]
+                sqlalchemy.and_(*where_conditions))).fetchone()[0]
+
+
+@_init_db
+def get_nonterminal_job_ids_by_pool(pool: str,
+                                    cluster_name: Optional[str] = None
+                                   ) -> List[int]:
+    """Get nonterminal job ids in a pool."""
+    assert _SQLALCHEMY_ENGINE is not None
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        query = sqlalchemy.select(
+            spot_table.c.spot_job_id.distinct()).select_from(
+                spot_table.outerjoin(
+                    job_info_table,
+                    spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
+        and_conditions = [
+            ~spot_table.c.status.in_([
+                status.value for status in ManagedJobStatus.terminal_statuses()
+            ]),
+            job_info_table.c.pool == pool,
+        ]
+        if cluster_name is not None:
+            and_conditions.append(
+                job_info_table.c.current_cluster_name == cluster_name)
+        query = query.where(sqlalchemy.and_(*and_conditions)).order_by(
+            spot_table.c.spot_job_id.asc())
+        rows = session.execute(query).fetchall()
+        job_ids = [row[0] for row in rows if row[0] is not None]
+        return job_ids
 
 
 @_init_db
-def get_waiting_job() -> Optional[Dict[str, Any]]:
+def get_waiting_job(pool: Optional[str]) -> Optional[Dict[str, Any]]:
    """Get the next job that should transition to LAUNCHING.
 
    Selects the highest-priority WAITING or ALIVE_WAITING job, provided its
@@ -1442,23 +1543,26 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
                 ManagedJobScheduleState.ALIVE_BACKOFF.value,
             ])).scalar_subquery()
         # Main query for waiting jobs
+        select_conds = [
+            job_info_table.c.schedule_state.in_([
+                ManagedJobScheduleState.WAITING.value,
+                ManagedJobScheduleState.ALIVE_WAITING.value,
+            ]),
+            job_info_table.c.priority >= sqlalchemy.func.coalesce(
+                max_priority_subquery, 0),
+        ]
+        if pool is not None:
+            select_conds.append(job_info_table.c.pool == pool)
         query = sqlalchemy.select(
             job_info_table.c.spot_job_id,
             job_info_table.c.schedule_state,
             job_info_table.c.dag_yaml_path,
             job_info_table.c.env_file_path,
-        ).where(
-            sqlalchemy.and_(
-                job_info_table.c.schedule_state.in_([
-                    ManagedJobScheduleState.WAITING.value,
-                    ManagedJobScheduleState.ALIVE_WAITING.value,
-                ]),
-                job_info_table.c.priority >= sqlalchemy.func.coalesce(
-                    max_priority_subquery, 0),
-            )).order_by(
-                job_info_table.c.priority.desc(),
-                job_info_table.c.spot_job_id.asc(),
-            ).limit(1)
+            job_info_table.c.pool,
+        ).where(sqlalchemy.and_(*select_conds)).order_by(
+            job_info_table.c.priority.desc(),
+            job_info_table.c.spot_job_id.asc(),
+        ).limit(1)
         waiting_job_row = session.execute(query).fetchone()
         if waiting_job_row is None:
             return None
@@ -1468,6 +1572,7 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
             'schedule_state': ManagedJobScheduleState(waiting_job_row[1]),
             'dag_yaml_path': waiting_job_row[2],
             'env_file_path': waiting_job_row[3],
+            'pool': waiting_job_row[4],
         }
 
 
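
With the new signature, the scheduler can scope selection to one pool or keep the old behavior by passing None. A sketch of an assumed call site (not shown in the diff):

    # Pool-scoped scheduling: only WAITING/ALIVE_WAITING jobs recorded for
    # 'my-pool' are considered; pool=None preserves the previous behavior.
    job = managed_job_state.get_waiting_job(pool='my-pool')
    if job is not None:
        print(job['schedule_state'], job['dag_yaml_path'], job['pool'])
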
sky/jobs/utils.py CHANGED
@@ -30,7 +30,6 @@ from sky.backends import backend_utils
 from sky.jobs import constants as managed_job_constants
 from sky.jobs import scheduler
 from sky.jobs import state as managed_job_state
-from sky.server import common as server_common
 from sky.skylet import constants
 from sky.skylet import job_lib
 from sky.skylet import log_lib
@@ -39,7 +38,6 @@ from sky.utils import annotations
 from sky.utils import command_runner
 from sky.utils import common_utils
 from sky.utils import controller_utils
-from sky.utils import env_options
 from sky.utils import infra_utils
 from sky.utils import log_utils
 from sky.utils import message_utils
@@ -136,12 +134,6 @@ def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
 def _validate_consolidation_mode_config(
         current_is_consolidation_mode: bool) -> None:
     """Validate the consolidation mode config."""
-    if (current_is_consolidation_mode and
-            not env_options.Options.IS_DEVELOPER.get() and
-            server_common.is_api_server_local()):
-        with ux_utils.print_exception_no_traceback():
-            raise exceptions.NotSupportedError(
-                'Consolidation mode is not supported when running locally.')
     # Check whether the consolidation mode config is changed.
     if current_is_consolidation_mode:
         controller_cn = (
@@ -239,8 +231,8 @@ def ha_recovery_for_consolidation_mode():
             f.write(f'Total recovery time: {time.time() - start} seconds\n')
 
 
-def get_job_status(backend: 'backends.CloudVmRayBackend',
-                   cluster_name: str) -> Optional['job_lib.JobStatus']:
+def get_job_status(backend: 'backends.CloudVmRayBackend', cluster_name: str,
+                   job_id: Optional[int]) -> Optional['job_lib.JobStatus']:
     """Check the status of the job running on a managed job cluster.
 
     It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
@@ -253,10 +245,13 @@ def get_job_status(backend: 'backends.CloudVmRayBackend',
         logger.info(f'Cluster {cluster_name} not found.')
         return None
     assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
+    job_ids = None if job_id is None else [job_id]
     for i in range(_JOB_STATUS_FETCH_MAX_RETRIES):
         try:
             logger.info('=== Checking the job status... ===')
-            statuses = backend.get_job_status(handle, stream_logs=False)
+            statuses = backend.get_job_status(handle,
+                                              job_ids=job_ids,
+                                              stream_logs=False)
             status = list(statuses.values())[0]
             if status is None:
                 logger.info('No job found.')
@@ -323,13 +318,20 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
     error_msg = None
     tasks = managed_job_state.get_managed_jobs(job_id)
     for task in tasks:
-        task_name = task['job_name']
-        cluster_name = generate_managed_job_cluster_name(task_name, job_id)
+        pool = task.get('pool', None)
+        if pool is None:
+            task_name = task['job_name']
+            cluster_name = generate_managed_job_cluster_name(
+                task_name, job_id)
+        else:
+            cluster_name, _ = (
+                managed_job_state.get_pool_submit_info(job_id))
         handle = global_user_state.get_handle_from_cluster_name(
             cluster_name)
         if handle is not None:
             try:
-                terminate_cluster(cluster_name)
+                if pool is None:
+                    terminate_cluster(cluster_name)
             except Exception as e:  # pylint: disable=broad-except
                 error_msg = (
                     f'Failed to terminate cluster {cluster_name}: '
@@ -510,10 +512,10 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
 
 
 def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
-                      get_end_time: bool) -> float:
+                      job_id: Optional[int], get_end_time: bool) -> float:
     """Get the submitted/ended time of the job."""
     code = job_lib.JobLibCodeGen.get_job_submitted_or_ended_timestamp_payload(
-        job_id=None, get_ended_time=get_end_time)
+        job_id=job_id, get_ended_time=get_end_time)
     handle = global_user_state.get_handle_from_cluster_name(cluster_name)
     returncode, stdout, stderr = backend.run_on_head(handle,
                                                      code,
@@ -527,14 +529,17 @@ def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str,
 
 
 def try_to_get_job_end_time(backend: 'backends.CloudVmRayBackend',
-                            cluster_name: str) -> float:
+                            cluster_name: str, job_id: Optional[int]) -> float:
     """Try to get the end time of the job.
 
     If the job is preempted or we can't connect to the instance for whatever
     reason, fall back to the current time.
     """
     try:
-        return get_job_timestamp(backend, cluster_name, get_end_time=True)
+        return get_job_timestamp(backend,
+                                 cluster_name,
+                                 job_id=job_id,
+                                 get_end_time=True)
     except exceptions.CommandError as e:
         if e.returncode == 255:
             # Failed to connect - probably the instance was preempted since the
@@ -556,8 +561,12 @@ def event_callback_func(job_id: int, task_id: int, task: 'sky.Task'):
     if event_callback is None or task is None:
         return
     event_callback = event_callback.strip()
-    cluster_name = generate_managed_job_cluster_name(
-        task.name, job_id) if task.name else None
+    pool = managed_job_state.get_pool_from_job_id(job_id)
+    if pool is not None:
+        cluster_name, _ = (managed_job_state.get_pool_submit_info(job_id))
+    else:
+        cluster_name = generate_managed_job_cluster_name(
+            task.name, job_id) if task.name else None
     logger.info(f'=== START: event callback for {status!r} ===')
     log_path = os.path.join(constants.SKY_LOGS_DIRECTORY,
                             'managed_job_event',
@@ -684,6 +693,15 @@ def cancel_job_by_name(job_name: str,
     return f'{job_name!r} {msg}'
 
 
+def cancel_jobs_by_pool(pool_name: str,
+                        current_workspace: Optional[str] = None) -> str:
+    """Cancel all jobs in a pool."""
+    job_ids = managed_job_state.get_nonterminal_job_ids_by_pool(pool_name)
+    if not job_ids:
+        return f'No running job found in pool {pool_name!r}.'
+    return cancel_jobs_by_id(job_ids, current_workspace=current_workspace)
+
+
 def stream_logs_by_id(job_id: int,
                       follow: bool = True,
                       tail: Optional[int] = None) -> Tuple[str, int]:
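
A sketch of calling the new helper directly on the controller (values assumed):

    # Cancels every nonterminal job recorded for the pool and returns a
    # human-readable message.
    msg = cancel_jobs_by_pool('my-pool', current_workspace='default')
    print(msg)  # e.g. "No running job found in pool 'my-pool'." when idle
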
@@ -777,12 +795,19 @@ def stream_logs_by_id(job_id: int,
 
     while should_keep_logging(managed_job_status):
         handle = None
+        job_id_to_tail = None
         if task_id is not None:
-            task_name = managed_job_state.get_task_name(job_id, task_id)
-            cluster_name = generate_managed_job_cluster_name(
-                task_name, job_id)
-            handle = global_user_state.get_handle_from_cluster_name(
-                cluster_name)
+            pool = managed_job_state.get_pool_from_job_id(job_id)
+            if pool is not None:
+                cluster_name, job_id_to_tail = (
+                    managed_job_state.get_pool_submit_info(job_id))
+            else:
+                task_name = managed_job_state.get_task_name(job_id, task_id)
+                cluster_name = generate_managed_job_cluster_name(
+                    task_name, job_id)
+            if cluster_name is not None:
+                handle = global_user_state.get_handle_from_cluster_name(
+                    cluster_name)
 
         # Check the handle: The cluster can be preempted and removed from
         # the table before the managed job state is updated by the
@@ -814,7 +839,7 @@ def stream_logs_by_id(job_id: int,
             status_display.stop()
             tail_param = tail if tail is not None else 0
             returncode = backend.tail_logs(handle,
-                                           job_id=None,
+                                           job_id=job_id_to_tail,
                                            managed_job_id=job_id,
                                            follow=follow,
                                            tail=tail_param)
@@ -1132,9 +1157,15 @@ def dump_managed_job_queue() -> str:
         job['status'] = job['status'].value
         job['schedule_state'] = job['schedule_state'].value
 
-        cluster_name = generate_managed_job_cluster_name(
-            job['task_name'], job['job_id'])
-        handle = global_user_state.get_handle_from_cluster_name(cluster_name)
+        pool = managed_job_state.get_pool_from_job_id(job['job_id'])
+        if pool is not None:
+            cluster_name, _ = managed_job_state.get_pool_submit_info(
+                job['job_id'])
+        else:
+            cluster_name = generate_managed_job_cluster_name(
+                job['task_name'], job['job_id'])
+        handle = global_user_state.get_handle_from_cluster_name(
+            cluster_name) if cluster_name is not None else None
         if isinstance(handle, backends.CloudVmRayResourceHandle):
             resources_str = resources_utils.get_readable_resources_repr(
                 handle, simplify=True)
@@ -1145,6 +1176,11 @@ def dump_managed_job_queue() -> str:
             job['cloud'] = str(handle.launched_resources.cloud)
             job['region'] = handle.launched_resources.region
             job['zone'] = handle.launched_resources.zone
+            job['infra'] = infra_utils.InfraInfo(
+                str(handle.launched_resources.cloud),
+                handle.launched_resources.region,
+                handle.launched_resources.zone).formatted_str()
+            job['accelerators'] = handle.launched_resources.accelerators
         else:
             # FIXME(zongheng): display the last cached values for these.
             job['cluster_resources'] = '-'
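
The new 'infra' field is precomputed on the controller, so clients no longer need to reassemble it from the separate cloud/region/zone fields. A sketch of the equivalent computation with assumed example values (the exact display string is whatever InfraInfo.formatted_str produces):

    from sky.utils import infra_utils

    infra = infra_utils.InfraInfo('AWS', 'us-east-1', 'us-east-1a')
    job_infra = infra.formatted_str()  # single display string for the table
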
@@ -1152,6 +1188,7 @@ def dump_managed_job_queue() -> str:
             job['cloud'] = '-'
             job['region'] = '-'
             job['zone'] = '-'
+            job['infra'] = '-'
 
         # Add details about schedule state / backoff.
         state_details = None
@@ -1292,10 +1329,13 @@ def format_job_table(
         'JOB DURATION',
         '#RECOVERIES',
         'STATUS',
+        'WORKER_POOL',
     ]
     if show_all:
         # TODO: move SCHED. STATE to a separate flag (e.g. --debug)
         columns += [
+            'WORKER_CLUSTER',
+            'WORKER_JOB_ID',
             'STARTED',
             'INFRA',
             'RESOURCES',
@@ -1405,11 +1445,14 @@ def format_job_table(
                 job_duration,
                 recovery_cnt,
                 status_str,
+                job_tasks[0].get('pool', '-'),
             ]
             if show_all:
                 details = job_tasks[current_task_id].get('details')
                 failure_reason = job_tasks[current_task_id]['failure_reason']
                 job_values.extend([
+                    '-',
+                    '-',
                     '-',
                     '-',
                     '-',
@@ -1445,37 +1488,43 @@ def format_job_table(
                 job_duration,
                 task['recovery_count'],
                 task['status'].colored_str(),
+                task.get('pool', '-'),
             ]
             if show_all:
                 # schedule_state is only set at the job level, so if we have
                 # more than one task, only display on the aggregated row.
                 schedule_state = (task['schedule_state']
                                   if len(job_tasks) == 1 else '-')
-                cloud = task.get('cloud')
-                if cloud is None:
-                    # Backward compatibility for old jobs controller without
-                    # cloud info returned, we parse it from the cluster
-                    # resources
-                    # TODO(zhwu): remove this after 0.12.0
-                    cloud = task['cluster_resources'].split('(')[0].split(
-                        'x')[-1]
-                    task['cluster_resources'] = task[
-                        'cluster_resources'].replace(f'{cloud}(',
-                                                     '(').replace('x ', 'x')
-                region = task['region']
-                zone = task.get('zone')
-                if cloud == '-':
-                    cloud = None
-                if region == '-':
-                    region = None
-                if zone == '-':
-                    zone = None
-
-                infra = infra_utils.InfraInfo(cloud, region, zone)
+                infra_str = task.get('infra')
+                if infra_str is None:
+                    cloud = task.get('cloud')
+                    if cloud is None:
+                        # Backward compatibility for old jobs controller without
+                        # cloud info returned, we parse it from the cluster
+                        # resources
+                        # TODO(zhwu): remove this after 0.12.0
+                        cloud = task['cluster_resources'].split('(')[0].split(
+                            'x')[-1]
+                        task['cluster_resources'] = task[
+                            'cluster_resources'].replace(f'{cloud}(',
+                                                         '(').replace(
+                                                             'x ', 'x')
+                    region = task['region']
+                    zone = task.get('zone')
+                    if cloud == '-':
+                        cloud = None
+                    if region == '-':
+                        region = None
+                    if zone == '-':
+                        zone = None
+                    infra_str = infra_utils.InfraInfo(cloud, region,
+                                                      zone).formatted_str()
                 values.extend([
+                    task.get('current_cluster_name', '-'),
+                    task.get('job_id_on_pool_cluster', '-'),
                     # STARTED
                     log_utils.readable_time_duration(task['start_at']),
-                    infra.formatted_str(),
+                    infra_str,
                     task['cluster_resources'],
                     schedule_state,
                     generate_details(task.get('details'),
@@ -1567,6 +1616,15 @@ class ManagedJobCodeGen:
         """)
         return cls._build(code)
 
+    @classmethod
+    def cancel_jobs_by_pool(cls, pool_name: str) -> str:
+        active_workspace = skypilot_config.get_active_workspace()
+        code = textwrap.dedent(f"""\
+        msg = utils.cancel_jobs_by_pool({pool_name!r}, {active_workspace!r})
+        print(msg, end="", flush=True)
+        """)
+        return cls._build(code)
+
     @classmethod
     def get_version_and_job_table(cls) -> str:
         """Generate code to get controller version and raw job table."""
sky/provision/nebius/constants.py CHANGED
@@ -15,6 +15,9 @@ INFINIBAND_ENV_VARS = {
                        'mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1')
 }
 
+# pylint: disable=line-too-long
+INFINIBAND_IMAGE_ID = 'docker:cr.eu-north1.nebius.cloud/nebius-benchmarks/nccl-tests:2.23.4-ubu22.04-cu12.4'
+
 # Docker run options for InfiniBand support
 INFINIBAND_DOCKER_OPTIONS = ['--device=/dev/infiniband', '--cap-add=IPC_LOCK']
 
sky/provision/runpod/utils.py CHANGED
@@ -270,18 +270,17 @@ def launch(cluster_name: str, node_type: str, instance_type: str, region: str,
            docker_login_config: Optional[Dict[str, str]]) -> str:
     """Launches an instance with the given parameters.
 
-    Converts the instance_type to the RunPod GPU name, finds the specs for the
-    GPU, and launches the instance.
+    For CPU instances, we directly use the instance_type for launching the
+    instance.
+
+    For GPU instances, we convert the instance_type to the RunPod GPU name
+    and find the specs for the GPU before launching the instance.
 
     Returns:
         instance_id: The instance ID.
     """
     name = f'{cluster_name}-{node_type}'
-    gpu_type = GPU_NAME_MAP[instance_type.split('_')[1]]
-    gpu_quantity = int(instance_type.split('_')[0].replace('x', ''))
-    cloud_type = instance_type.split('_')[2]
 
-    gpu_specs = runpod.runpod.get_gpu(gpu_type)
     # TODO(zhwu): keep this align with setups in
     # `provision.kuberunetes.instance.py`
     setup_cmd = (
@@ -329,12 +328,7 @@ def launch(cluster_name: str, node_type: str, instance_type: str, region: str,
     params = {
         'name': name,
         'image_name': image_name_formatted,
-        'gpu_type_id': gpu_type,
-        'cloud_type': cloud_type,
         'container_disk_in_gb': disk_size,
-        'min_vcpu_count': 4 * gpu_quantity,
-        'min_memory_in_gb': gpu_specs['memoryInGb'] * gpu_quantity,
-        'gpu_count': gpu_quantity,
         'country_code': region,
         'data_center_id': zone,
         'ports': ports_str,
@@ -343,12 +337,33 @@ def launch(cluster_name: str, node_type: str, instance_type: str, region: str,
         'template_id': template_id,
     }
 
+    # GPU instance types start with f'{gpu_count}x',
+    # CPU instance types start with 'cpu'.
+    is_cpu_instance = instance_type.startswith('cpu')
+    if is_cpu_instance:
+        # RunPod CPU instances can be uniquely identified by the instance_id.
+        params.update({
+            'instance_id': instance_type,
+        })
+    else:
+        gpu_type = GPU_NAME_MAP[instance_type.split('_')[1]]
+        gpu_quantity = int(instance_type.split('_')[0].replace('x', ''))
+        cloud_type = instance_type.split('_')[2]
+        gpu_specs = runpod.runpod.get_gpu(gpu_type)
+        params.update({
+            'gpu_type_id': gpu_type,
+            'cloud_type': cloud_type,
+            'min_vcpu_count': 4 * gpu_quantity,
+            'min_memory_in_gb': gpu_specs['memoryInGb'] * gpu_quantity,
+            'gpu_count': gpu_quantity,
+        })
+
     if preemptible is None or not preemptible:
         new_instance = runpod.runpod.create_pod(**params)
     else:
         new_instance = runpod_commands.create_spot_pod(
             bid_per_gpu=bid_per_gpu,
-            **params,
+            **params,  # type: ignore[arg-type]
         )
 
     return new_instance['id']
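
A worked example of the dispatch above (the instance-type strings are assumed, following the f'{gpu_count}x_...' and 'cpu...' patterns named in the comments):

    # GPU form: '<count>x_<GPU name>_<cloud type>'
    instance_type = '2x_A100_SECURE'             # hypothetical value
    gpu_quantity = int(instance_type.split('_')[0].replace('x', ''))  # -> 2
    gpu_name_key = instance_type.split('_')[1]   # -> 'A100', key into GPU_NAME_MAP
    cloud_type = instance_type.split('_')[2]     # -> 'SECURE'

    # CPU form is detected purely by prefix and passed through unchanged:
    assert 'cpu3c-2-4'.startswith('cpu')         # hypothetical CPU instance type
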
sky/py.typed ADDED
File without changes