skypilot-nightly 1.0.0.dev20250225__py3-none-any.whl → 1.0.0.dev20250227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sky/jobs/server/core.py CHANGED
@@ -140,6 +140,7 @@ def launch(
     prefix = managed_job_constants.JOBS_TASK_YAML_PREFIX
     remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml'
     remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml'
+    remote_env_file_path = f'{prefix}/{dag.name}-{dag_uuid}.env'
     controller_resources = controller_utils.get_controller_resources(
         controller=controller_utils.Controllers.JOBS_CONTROLLER,
         task_resources=sum([list(t.resources) for t in dag.tasks], []))
@@ -152,6 +153,7 @@ def launch(
         # Note: actual cluster name will be <task.name>-<managed job ID>
         'dag_name': dag.name,
         'remote_user_config_path': remote_user_config_path,
+        'remote_env_file_path': remote_env_file_path,
         'modified_catalogs':
             service_catalog_common.get_modified_catalog_file_mounts(),
         'dashboard_setup_cmd': managed_job_constants.DASHBOARD_SETUP_CMD,
@@ -318,7 +320,9 @@ def _maybe_restart_controller(


 @usage_lib.entrypoint
-def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
+def queue(refresh: bool,
+          skip_finished: bool = False,
+          all_users: bool = False) -> List[Dict[str, Any]]:
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Gets statuses of managed jobs.

@@ -366,6 +370,19 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
                 f'{returncode}')

     jobs = managed_job_utils.load_managed_job_queue(job_table_payload)
+
+    if not all_users:
+
+        def user_hash_matches_or_missing(job: Dict[str, Any]) -> bool:
+            user_hash = job.get('user_hash', None)
+            if user_hash is None:
+                # For backwards compatibility, we show jobs that do not have a
+                # user_hash. TODO(cooperc): Remove before 0.12.0.
+                return True
+            return user_hash == common_utils.get_user_hash()
+
+        jobs = list(filter(user_hash_matches_or_missing, jobs))
+
     if skip_finished:
         # Filter out the finished jobs. If a multi-task job is partially
         # finished, we will include all its tasks.
@@ -374,6 +391,7 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
         non_finished_job_ids = {job['job_id'] for job in non_finished_tasks}
         jobs = list(
             filter(lambda job: job['job_id'] in non_finished_job_ids, jobs))
+
     return jobs


@@ -381,7 +399,8 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
 # pylint: disable=redefined-builtin
 def cancel(name: Optional[str] = None,
            job_ids: Optional[List[int]] = None,
-           all: bool = False) -> None:
+           all: bool = False,
+           all_users: bool = False) -> None:
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Cancels managed jobs.

@@ -397,17 +416,22 @@ def cancel(name: Optional[str] = None,
         stopped_message='All managed jobs should have finished.')

     job_id_str = ','.join(map(str, job_ids))
-    if sum([bool(job_ids), name is not None, all]) != 1:
-        argument_str = f'job_ids={job_id_str}' if job_ids else ''
-        argument_str += f' name={name}' if name is not None else ''
-        argument_str += ' all' if all else ''
+    if sum([bool(job_ids), name is not None, all or all_users]) != 1:
+        arguments = []
+        arguments += [f'job_ids={job_id_str}'] if job_ids else []
+        arguments += [f'name={name}'] if name is not None else []
+        arguments += ['all'] if all else []
+        arguments += ['all_users'] if all_users else []
         with ux_utils.print_exception_no_traceback():
-            raise ValueError('Can only specify one of JOB_IDS or name or all. '
-                             f'Provided {argument_str!r}.')
+            raise ValueError('Can only specify one of JOB_IDS, name, or all/'
+                             f'all_users. Provided {" ".join(arguments)!r}.')

     backend = backend_utils.get_backend_from_handle(handle)
     assert isinstance(backend, backends.CloudVmRayBackend)
-    if all:
+    if all_users:
+        code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(
+            None, all_users=True)
+    elif all:
         code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(None)
     elif job_ids:
         code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(job_ids)
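
Together, the changes above extend the managed jobs API so that listing and cancellation can span all users. A minimal usage sketch, assuming the client-side `sky.jobs` module forwards these new flags to the server handlers shown above (the exact client call shape may differ):

# Usage sketch only: assumes sky.jobs mirrors the server-side signatures.
import sky.jobs as managed_jobs

# Default: only the current user's jobs are returned.
my_jobs = managed_jobs.queue(refresh=False, skip_finished=True)

# New in this release: list jobs from all users.
all_jobs = managed_jobs.queue(refresh=False, skip_finished=True,
                              all_users=True)

# Cancel all of the current user's jobs, or every user's jobs.
managed_jobs.cancel(all=True)
managed_jobs.cancel(all_users=True)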
sky/jobs/server/server.py CHANGED
@@ -109,9 +109,18 @@ async def download_logs(
 @router.get('/dashboard')
 async def dashboard(request: fastapi.Request,
                     user_hash: str) -> fastapi.Response:
+    # TODO(cooperc): Support showing only jobs for a specific user.
+
+    # FIX(zhwu/cooperc/eric): Fix log downloading (assumes global
+    # /download_log/xx route)
+
     # Note: before #4717, each user had their own controller, and thus their own
     # dashboard. Now, all users share the same controller, so this isn't really
     # necessary. TODO(cooperc): clean up.
+
+    # TODO: Put this in an executor to avoid blocking the main server thread.
+    # It can take a long time if it needs to check the controller status.
+
     # Find the port for the dashboard of the user
     os.environ[constants.USER_ID_ENV_VAR] = user_hash
     server_common.reload_for_new_request(client_entrypoint=None,
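
The executor TODO above corresponds to a standard asyncio pattern for keeping blocking work off the event loop. A minimal sketch under that assumption (the handler and helper names here are illustrative, not SkyPilot's actual code):

# Sketch only: off-loading a blocking call from an async FastAPI handler,
# as suggested by the TODO above.
import asyncio

import fastapi

app = fastapi.FastAPI()


def _check_controller_status() -> str:
    # Placeholder for a slow, blocking status check.
    return 'UP'


@app.get('/dashboard-status')
async def dashboard_status() -> dict:
    loop = asyncio.get_running_loop()
    # run_in_executor(None, ...) uses the default thread pool, so the
    # event loop stays free while the blocking check runs.
    status = await loop.run_in_executor(None, _check_controller_status)
    return {'status': status}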
sky/jobs/state.py CHANGED
@@ -116,7 +116,9 @@ def create_table(cursor, conn):
         name TEXT,
         schedule_state TEXT,
         controller_pid INTEGER DEFAULT NULL,
-        dag_yaml_path TEXT)""")
+        dag_yaml_path TEXT,
+        env_file_path TEXT,
+        user_hash TEXT)""")

     db_utils.add_column_to_table(cursor, conn, 'job_info', 'schedule_state',
                                  'TEXT')
@@ -127,6 +129,11 @@ def create_table(cursor, conn):
     db_utils.add_column_to_table(cursor, conn, 'job_info', 'dag_yaml_path',
                                  'TEXT')

+    db_utils.add_column_to_table(cursor, conn, 'job_info', 'env_file_path',
+                                 'TEXT')
+
+    db_utils.add_column_to_table(cursor, conn, 'job_info', 'user_hash', 'TEXT')
+
     conn.commit()


@@ -181,6 +188,8 @@ columns = [
     'schedule_state',
     'controller_pid',
     'dag_yaml_path',
+    'env_file_path',
+    'user_hash',
 ]


@@ -683,20 +692,24 @@ def set_local_log_file(job_id: int, task_id: Optional[int],


 # ======== utility functions ========
-def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
+def get_nonterminal_job_ids_by_name(name: Optional[str],
+                                    all_users: bool = False) -> List[int]:
     """Get non-terminal job ids by name."""
     statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses()))
     field_values = [
         status.value for status in ManagedJobStatus.terminal_statuses()
     ]

-    name_filter = ''
+    job_filter = ''
+    if name is None and not all_users:
+        job_filter += 'AND (job_info.user_hash=(?)) '
+        field_values.append(common_utils.get_user_hash())
     if name is not None:
         # We match the job name from `job_info` for the jobs submitted after
         # #1982, and from `spot` for the jobs submitted before #1982, whose
         # job_info is not available.
-        name_filter = ('AND (job_info.name=(?) OR '
-                       '(job_info.name IS NULL AND spot.task_name=(?)))')
+        job_filter += ('AND (job_info.name=(?) OR '
+                       '(job_info.name IS NULL AND spot.task_name=(?))) ')
         field_values.extend([name, name])

     # Left outer join is used here instead of join, because the job_info does
@@ -710,7 +723,7 @@ def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
             ON spot.spot_job_id=job_info.spot_job_id
             WHERE status NOT IN
             ({statuses})
-            {name_filter}
+            {job_filter}
             ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
         job_ids = [row[0] for row in rows if row[0] is not None]
         return job_ids
@@ -906,6 +919,9 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
     # existing controller before #1982, the job_info table may not exist,
     # and all the managed jobs created before will not present in the
     # job_info.
+    # Note: we will get the user_hash here, but don't try to call
+    # global_user_state.get_user() on it. This runs on the controller, which may
+    # not have the user info. Prefer to do it on the API server side.
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         rows = cursor.execute(f"""\
             SELECT *
@@ -978,14 +994,17 @@ def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
 # scheduler lock to work correctly.


-def scheduler_set_waiting(job_id: int, dag_yaml_path: str) -> None:
+def scheduler_set_waiting(job_id: int, dag_yaml_path: str, env_file_path: str,
+                          user_hash: str) -> None:
     """Do not call without holding the scheduler lock."""
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
-            'schedule_state = (?), dag_yaml_path = (?) '
+            'schedule_state = (?), dag_yaml_path = (?), env_file_path = (?), '
+            ' user_hash = (?) '
             'WHERE spot_job_id = (?) AND schedule_state = (?)',
-            (ManagedJobScheduleState.WAITING.value, dag_yaml_path, job_id,
+            (ManagedJobScheduleState.WAITING.value, dag_yaml_path,
+             env_file_path, user_hash, job_id,
              ManagedJobScheduleState.INACTIVE.value)).rowcount
         assert updated_count == 1, (job_id, updated_count)

@@ -1085,7 +1104,7 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
     """
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         row = cursor.execute(
-            'SELECT spot_job_id, schedule_state, dag_yaml_path '
+            'SELECT spot_job_id, schedule_state, dag_yaml_path, env_file_path '
            'FROM job_info '
            'WHERE schedule_state in (?, ?) '
            'ORDER BY spot_job_id LIMIT 1',
@@ -1095,4 +1114,5 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
             'job_id': row[0],
             'schedule_state': ManagedJobScheduleState(row[1]),
             'dag_yaml_path': row[2],
+            'env_file_path': row[3],
         } if row is not None else None
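
The `db_utils.add_column_to_table` calls above follow the file's usual forward-only migration pattern: new columns appear both in `CREATE TABLE` (fresh databases) and as idempotent backfills (existing databases). A self-contained sketch of the idea with plain sqlite3 (illustrative; not SkyPilot's exact helper):

# Illustrative sketch of an idempotent "add column if missing" migration,
# mirroring what db_utils.add_column_to_table does for env_file_path and
# user_hash. Not SkyPilot's exact implementation.
import sqlite3


def add_column_if_missing(conn: sqlite3.Connection, table: str, column: str,
                          col_type: str) -> None:
    # PRAGMA table_info returns one row per column; row[1] is the name.
    cols = [row[1] for row in conn.execute(f'PRAGMA table_info({table})')]
    if column not in cols:
        conn.execute(f'ALTER TABLE {table} ADD COLUMN {column} {col_type}')
        conn.commit()


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE job_info (spot_job_id INTEGER, name TEXT)')
add_column_if_missing(conn, 'job_info', 'env_file_path', 'TEXT')
add_column_if_missing(conn, 'job_info', 'user_hash', 'TEXT')
# Safe to run again: the column check makes the migration idempotent.
add_column_if_missing(conn, 'job_info', 'user_hash', 'TEXT')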
sky/jobs/utils.py CHANGED
@@ -449,13 +449,15 @@ def generate_managed_job_cluster_name(task_name: str, job_id: int) -> str:
     return f'{cluster_name}-{job_id}'


-def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
+def cancel_jobs_by_id(job_ids: Optional[List[int]],
+                      all_users: bool = False) -> str:
     """Cancel jobs by id.

     If job_ids is None, cancel all jobs.
     """
     if job_ids is None:
-        job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None)
+        job_ids = managed_job_state.get_nonterminal_job_ids_by_name(
+            None, all_users)
     job_ids = list(set(job_ids))
     if not job_ids:
         return 'No job to cancel.'
@@ -917,6 +919,7 @@ def _get_job_status_from_tasks(
 @typing.overload
 def format_job_table(tasks: List[Dict[str, Any]],
                      show_all: bool,
+                     show_user: bool,
                      return_rows: Literal[False] = False,
                      max_jobs: Optional[int] = None) -> str:
     ...
@@ -925,6 +928,7 @@ def format_job_table(tasks: List[Dict[str, Any]],
 @typing.overload
 def format_job_table(tasks: List[Dict[str, Any]],
                      show_all: bool,
+                     show_user: bool,
                      return_rows: Literal[True],
                      max_jobs: Optional[int] = None) -> List[List[str]]:
     ...
@@ -933,6 +937,7 @@ def format_job_table(tasks: List[Dict[str, Any]],
 def format_job_table(
         tasks: List[Dict[str, Any]],
         show_all: bool,
+        show_user: bool,
         return_rows: bool = False,
         max_jobs: Optional[int] = None) -> Union[str, List[List[str]]]:
     """Returns managed jobs as a formatted string.
@@ -948,13 +953,14 @@ def format_job_table(
         a list of "rows" (each of which is a list of str).
     """
     jobs = collections.defaultdict(list)
-    # Check if the tasks have user information.
-    tasks_have_user = any([task.get('user') for task in tasks])
-    if max_jobs and tasks_have_user:
+    # Check if the tasks have user information from kubernetes.
+    # This is only used for sky status --kubernetes.
+    tasks_have_k8s_user = any([task.get('user') for task in tasks])
+    if max_jobs and tasks_have_k8s_user:
         raise ValueError('max_jobs is not supported when tasks have user info.')

     def get_hash(task):
-        if tasks_have_user:
+        if tasks_have_k8s_user:
             return (task['user'], task['job_id'])
         return task['job_id']

@@ -969,10 +975,17 @@ def format_job_table(
         if not managed_job_status.is_terminal():
             status_counts[managed_job_status.value] += 1

+    user_cols: List[str] = []
+    if show_user:
+        user_cols = ['USER']
+        if show_all:
+            user_cols.append('USER_ID')
+
     columns = [
         'ID',
         'TASK',
         'NAME',
+        *user_cols,
         'RESOURCES',
         'SUBMITTED',
         'TOT. DURATION',
@@ -983,7 +996,7 @@ def format_job_table(
     if show_all:
         # TODO: move SCHED. STATE to a separate flag (e.g. --debug)
         columns += ['STARTED', 'CLUSTER', 'REGION', 'SCHED. STATE', 'DETAILS']
-    if tasks_have_user:
+    if tasks_have_k8s_user:
         columns.insert(0, 'USER')
     job_table = log_utils.create_table(columns)

@@ -1006,6 +1019,22 @@ def format_job_table(
             return f'Failure: {failure_reason}'
         return '-'

+    def get_user_column_values(task: Dict[str, Any]) -> List[str]:
+        user_values: List[str] = []
+        if show_user:
+
+            user_name = '-'
+            user_hash = task.get('user_hash', None)
+            if user_hash:
+                user = global_user_state.get_user(user_hash)
+                user_name = user.name if user.name else '-'
+            user_values = [user_name]
+
+            if show_all:
+                user_values.append(user_hash if user_hash is not None else '-')
+
+        return user_values
+
     for job_hash, job_tasks in jobs.items():
         if show_all:
             schedule_state = job_tasks[0]['schedule_state']
@@ -1044,11 +1073,14 @@ def format_job_table(
             if not managed_job_status.is_terminal():
                 status_str += f' (task: {current_task_id})'

-            job_id = job_hash[1] if tasks_have_user else job_hash
+            user_values = get_user_column_values(job_tasks[0])
+
+            job_id = job_hash[1] if tasks_have_k8s_user else job_hash
             job_values = [
                 job_id,
                 '',
                 job_name,
+                *user_values,
                 '-',
                 submitted,
                 total_duration,
@@ -1065,7 +1097,7 @@ def format_job_table(
                     job_tasks[0]['schedule_state'],
                     generate_details(failure_reason),
                 ])
-            if tasks_have_user:
+            if tasks_have_k8s_user:
                 job_values.insert(0, job_tasks[0].get('user', '-'))
             job_table.add_row(job_values)

@@ -1075,10 +1107,12 @@ def format_job_table(
             job_duration = log_utils.readable_time_duration(
                 0, task['job_duration'], absolute=True)
             submitted = log_utils.readable_time_duration(task['submitted_at'])
+            user_values = get_user_column_values(task)
             values = [
                 task['job_id'] if len(job_tasks) == 1 else ' \u21B3',
                 task['task_id'] if len(job_tasks) > 1 else '-',
                 task['task_name'],
+                *user_values,
                 task['resources'],
                 # SUBMITTED
                 submitted if submitted != '-' else submitted,
@@ -1103,7 +1137,7 @@ def format_job_table(
                     schedule_state,
                     generate_details(task['failure_reason']),
                 ])
-            if tasks_have_user:
+            if tasks_have_k8s_user:
                 values.insert(0, task.get('user', '-'))
             job_table.add_row(values)

@@ -1135,6 +1169,9 @@ class ManagedJobCodeGen:
     _PREFIX = textwrap.dedent("""\
         from sky.jobs import utils
         from sky.jobs import state as managed_job_state
+        from sky.jobs import constants as managed_job_constants
+
+        managed_job_version = managed_job_constants.MANAGED_JOBS_VERSION
         """)

     @classmethod
@@ -1146,9 +1183,17 @@ class ManagedJobCodeGen:
         return cls._build(code)

     @classmethod
-    def cancel_jobs_by_id(cls, job_ids: Optional[List[int]]) -> str:
+    def cancel_jobs_by_id(cls,
+                          job_ids: Optional[List[int]],
+                          all_users: bool = False) -> str:
         code = textwrap.dedent(f"""\
-        msg = utils.cancel_jobs_by_id({job_ids})
+        if managed_job_version < 2:
+            # For backward compatibility, since all_users is not supported
+            # before #4787. Assume th
+            # TODO(cooperc): Remove compatibility before 0.12.0
+            msg = utils.cancel_jobs_by_id({job_ids})
+        else:
+            msg = utils.cancel_jobs_by_id({job_ids}, all_users={all_users})
         print(msg, end="", flush=True)
         """)
         return cls._build(code)
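
The `managed_job_version < 2` branch illustrates how ManagedJobCodeGen stays backward compatible: the API server emits Python source that executes on the jobs controller, so the version check must happen at runtime on the controller, not at generation time on the server. A simplified sketch of that pattern (names abbreviated; not the exact SkyPilot code):

# Simplified sketch of version-gated code generation, mirroring
# ManagedJobCodeGen.cancel_jobs_by_id above. The generated source runs on
# the controller, whose library version may lag behind the API server's.
import textwrap


def gen_cancel_code(job_ids, all_users: bool = False) -> str:
    return textwrap.dedent(f"""\
        from sky.jobs import constants as managed_job_constants
        from sky.jobs import utils

        if managed_job_constants.MANAGED_JOBS_VERSION < 2:
            # Old controller: all_users is not understood; fall back.
            msg = utils.cancel_jobs_by_id({job_ids!r})
        else:
            msg = utils.cancel_jobs_by_id({job_ids!r}, all_users={all_users})
        print(msg, end="", flush=True)
        """)


# The returned string is shipped to, and executed on, the controller.
print(gen_cancel_code(None, all_users=True))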
sky/server/constants.py CHANGED
@@ -3,7 +3,7 @@
 # API server version, whenever there is a change in API server that requires a
 # restart of the local API server or error out when the client does not match
 # the server version.
-API_VERSION = '1'
+API_VERSION = '2'

 # Prefix for API request names.
 REQUEST_NAME_PREFIX = 'sky.'
sky/server/requests/payloads.py CHANGED
@@ -322,6 +322,7 @@ class JobsQueueBody(RequestBody):
     """The request body for the jobs queue endpoint."""
     refresh: bool = False
     skip_finished: bool = False
+    all_users: bool = False


 class JobsCancelBody(RequestBody):
@@ -329,6 +330,7 @@ class JobsCancelBody(RequestBody):
     name: Optional[str]
     job_ids: Optional[List[int]]
     all: bool = False
+    all_users: bool = False


 class JobsLogsBody(RequestBody):
sky/templates/jobs-controller.yaml.j2 CHANGED
@@ -55,12 +55,19 @@ setup: |

 run: |
   {{ sky_activate_python_env }}
+
+  # Write env vars to a file
+  {%- for env_name, env_value in controller_envs.items() %}
+  echo "export {{env_name}}='{{env_value}}'" >> {{remote_env_file_path}}
+  {%- endfor %}
+
   # Submit the job to the scheduler.
   # Note: The job is already in the `spot` table, marked as PENDING.
   # CloudVmRayBackend._exec_code_on_head() calls
   # managed_job_codegen.set_pending() before we get here.
   python -u -m sky.jobs.scheduler {{remote_user_yaml_path}} \
-    --job-id $SKYPILOT_INTERNAL_JOB_ID
+    --job-id $SKYPILOT_INTERNAL_JOB_ID \
+    --env-file {{remote_env_file_path}}


 envs:
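
On the controller, the loop above renders to plain `export NAME='value'` lines in the `.env` file that the scheduler receives via `--env-file`. A sketch of parsing such a file back into a dict (the scheduler's actual `--env-file` handling may differ):

# Sketch: parsing an env file of `export NAME='value'` lines, like the one
# the template writes to {{remote_env_file_path}}. SkyPilot's actual
# handling in sky.jobs.scheduler may differ.
import tempfile


def parse_env_file(path: str) -> dict:
    envs = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line.startswith('export '):
                continue
            name, _, value = line[len('export '):].partition('=')
            envs[name] = value.strip("'")
    return envs


with tempfile.NamedTemporaryFile('w', suffix='.env', delete=False) as f:
    f.write("export SKYPILOT_DEV='1'\n")
    env_path = f.name
print(parse_env_file(env_path))  # {'SKYPILOT_DEV': '1'}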
sky/utils/common_utils.py CHANGED
@@ -774,13 +774,10 @@ def is_port_available(port: int, reuse_addr: bool = True) -> bool:
     return False


-# TODO(aylei): should be aware of cgroups
 def get_cpu_count() -> int:
-    """Get the number of CPUs.
-
-    If the API server is deployed as a pod in k8s cluster, we assume the
-    number of CPUs is provided by the downward API.
-    """
+    """Get the number of CPUs, with cgroup awareness."""
+    # This env-var is kept since it is still useful for limiting the resource
+    # of SkyPilot in non-containerized environments.
     cpu_count = os.getenv('SKYPILOT_POD_CPU_CORE_LIMIT')
     if cpu_count is not None:
         try:
@@ -790,16 +787,11 @@ def get_cpu_count() -> int:
             raise ValueError(
                 f'Failed to parse the number of CPUs from {cpu_count}'
             ) from e
-    return psutil.cpu_count()
+    return _cpu_count()


-# TODO(aylei): should be aware of cgroups
 def get_mem_size_gb() -> float:
-    """Get the memory size in GB.
-
-    If the API server is deployed as a pod in k8s cluster, we assume the
-    memory size is provided by the downward API.
-    """
+    """Get the memory size in GB, with cgroup awareness."""
     mem_size = os.getenv('SKYPILOT_POD_MEMORY_GB_LIMIT')
     if mem_size is not None:
         try:
@@ -808,4 +800,92 @@ def get_mem_size_gb() -> float:
             with ux_utils.print_exception_no_traceback():
                 raise ValueError(
                     f'Failed to parse the memory size from {mem_size}') from e
-    return psutil.virtual_memory().total / (1024**3)
+    return _mem_size_gb()
+
+
+def _cpu_count() -> int:
+    # host cpu cores (logical)
+    cpu = psutil.cpu_count()
+    # cpu affinity on Linux
+    if hasattr(os, 'sched_getaffinity'):
+        # just for safe, length of CPU set should always <= logical cpu cores
+        cpu = min(cpu, len(os.sched_getaffinity(0)))
+    cgroup_cpu = _get_cgroup_cpu_limit()
+    if cgroup_cpu is not None:
+        cpu = min(cpu, int(cgroup_cpu))
+    return cpu
+
+
+def _mem_size_gb() -> float:
+    # host memory limit
+    mem = psutil.virtual_memory().total
+    cgroup_mem = _get_cgroup_memory_limit()
+    if cgroup_mem is not None:
+        mem = min(mem, cgroup_mem)
+    return mem / (1024**3)
+
+
+# Refer to:
+# - https://docs.kernel.org/admin-guide/cgroup-v1/index.html
+# - https://docs.kernel.org/admin-guide/cgroup-v2.html
+# for the standards of handler files in cgroupv1 and v2.
+# Since all those paths are well-known standards that are unlikely to change,
+# we use string literals instead of defining extra constants.
+def _get_cgroup_cpu_limit() -> Optional[float]:
+    """Return cpu limit from cgroups in cores.
+
+    Returns:
+        The cpu limit in cores as a float (can be fractional), or None if there
+        is no limit in cgroups.
+    """
+    try:
+        if _is_cgroup_v2():
+            with open('/sys/fs/cgroup/cpu.max', 'r', encoding='utf-8') as f:
+                quota_str, period_str = f.read().strip().split()
+                if quota_str == 'max':
+                    return None
+                quota = float(quota_str)
+                period = float(period_str)
+                return quota / period if quota > 0 else None
+        else:
+            # cgroup v1
+            with open('/sys/fs/cgroup/cpu/cpu.cfs_quota_us',
+                      'r',
+                      encoding='utf-8') as f:
+                quota = float(f.read().strip())
+            with open('/sys/fs/cgroup/cpu/cpu.cfs_period_us',
+                      'r',
+                      encoding='utf-8') as f:
+                period = float(f.read().strip())
+            # Return unlimited if cpu quota is not set.
+            # Note that we do not use cpu.shares since it is a relative weight
+            # instead of a hard limit. It is okay to get CPU throttling under
+            # high contention. And unlimited enables the server to use as much
+            # CPU as available if there is no contention.
+            return quota / period if (quota > 0 and period > 0) else None
+    except (OSError, ValueError):
+        return None
+
+
+def _get_cgroup_memory_limit() -> Optional[int]:
+    """Return memory limit from cgroups in bytes.
+
+    Returns:
+        The memory limit in bytes, or None if there is no limit in cgroups.
+    """
+    try:
+        path = ('/sys/fs/cgroup/memory.max' if _is_cgroup_v2() else
+                '/sys/fs/cgroup/memory/memory.limit_in_bytes')
+        with open(path, 'r', encoding='utf-8') as f:
+            value = f.read().strip()
+            if value == 'max' or not value:
+                return None
+            limit = int(value)
+            return limit if limit > 0 else None
+    except (OSError, ValueError):
+        return None
+
+
+def _is_cgroup_v2() -> bool:
+    """Return True if the environment is running cgroup v2."""
+    return os.path.isfile('/sys/fs/cgroup/cgroup.controllers')
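
With these helpers, `get_cpu_count()` and `get_mem_size_gb()` now respect container limits instead of reporting host totals. For example (a sketch; actual output depends on the environment):

# Sketch: in a container run with e.g. `docker run --cpus=2 --memory=4g ...`,
# the helpers above report the cgroup limits instead of the host totals.
from sky.utils import common_utils

print(common_utils.get_cpu_count())    # e.g. 2 on a 16-core host with --cpus=2
print(common_utils.get_mem_size_gb())  # e.g. ~4.0 with --memory=4g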
skypilot_nightly-1.0.0.dev20250225.dist-info/METADATA → skypilot_nightly-1.0.0.dev20250227.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: skypilot-nightly
-Version: 1.0.0.dev20250225
+Version: 1.0.0.dev20250227
 Summary: SkyPilot: An intercloud broker for the clouds
 Author: SkyPilot Team
 License: Apache 2.0
@@ -169,7 +169,7 @@ Dynamic: summary

 <p align="center">
   <a href="https://docs.skypilot.co/">
-    <img alt="Documentation" src="https://readthedocs.org/projects/skypilot/badge/?version=latest">
+    <img alt="Documentation" src="https://img.shields.io/badge/docs-gray?logo=readthedocs&logoColor=f5f5f5">
   </a>

   <a href="https://github.com/skypilot-org/skypilot/releases">
@@ -192,6 +192,7 @@ Dynamic: summary

 ----
 :fire: *News* :fire:
+- [Feb 2025] Prepare and serve **Retrieval Augmented Generation (RAG) with DeepSeek-R1**: [**blog post**](https://blog.skypilot.co/deepseek-rag), [**example**](./llm/rag/)
 - [Feb 2025] Run and serve **DeepSeek-R1 671B** using SkyPilot and SGLang with high throughput: [**example**](./llm/deepseek-r1/)
 - [Feb 2025] Prepare and serve large-scale image search with **vector databases**: [**blog post**](https://blog.skypilot.co/large-scale-vector-database/), [**example**](./examples/vector_database/)
 - [Jan 2025] Launch and serve distilled models from **[DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1)** and **[Janus](https://github.com/deepseek-ai/DeepSeek-Janus)** on Kubernetes or any cloud: [**R1 example**](./llm/deepseek-r1-distilled/) and [**Janus example**](./llm/deepseek-janus/)