skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/usage/usage_lib.py CHANGED
@@ -3,7 +3,6 @@
3
3
  import contextlib
4
4
  import datetime
5
5
  import enum
6
- import inspect
7
6
  import json
8
7
  import os
9
8
  import time
@@ -11,20 +10,27 @@ import traceback
11
10
  import typing
12
11
  from typing import Any, Callable, Dict, List, Optional, Union
13
12
 
14
- import click
15
- import requests
16
-
17
13
  import sky
18
14
  from sky import sky_logging
15
+ from sky.adaptors import common as adaptors_common
19
16
  from sky.usage import constants
20
17
  from sky.utils import common_utils
21
18
  from sky.utils import env_options
22
19
  from sky.utils import ux_utils
23
20
 
24
21
  if typing.TYPE_CHECKING:
22
+ import inspect
23
+
24
+ import requests
25
+
25
26
  from sky import resources as resources_lib
26
- from sky import status_lib
27
27
  from sky import task as task_lib
28
+ from sky.utils import status_lib
29
+ else:
30
+ # requests and inspect cost ~100ms to load, which can be postponed to
31
+ # collection phase or skipped if user specifies no collection
32
+ requests = adaptors_common.LazyImport('requests')
33
+ inspect = adaptors_common.LazyImport('inspect')
28
34
 
29
35
  logger = sky_logging.init_logger(__name__)
30
36
 
@@ -36,6 +42,7 @@ def _get_current_timestamp_ns() -> int:
36
42
  class MessageType(enum.Enum):
37
43
  """Types for messages to be sent to Loki."""
38
44
  USAGE = 'usage'
45
+ HEARTBEAT = 'heartbeat'
39
46
  # TODO(zhwu): Add more types, e.g., cluster_lifecycle.
40
47
 
41
48
 
@@ -59,8 +66,9 @@ class MessageToReport:
59
66
  properties = self.__dict__.copy()
60
67
  return {k: v for k, v in properties.items() if not k.startswith('_')}
61
68
 
62
- def __repr__(self):
63
- raise NotImplementedError
69
+ def __repr__(self) -> str:
70
+ d = self.get_properties()
71
+ return json.dumps(d)
64
72
 
65
73
 
66
74
  class UsageMessageToReport(MessageToReport):
@@ -75,7 +83,12 @@ class UsageMessageToReport(MessageToReport):
75
83
  self.sky_commit: str = sky.__commit__
76
84
 
77
85
  # Entry
78
- self.cmd: str = common_utils.get_pretty_entry_point()
86
+ self.cmd: Optional[str] = common_utils.get_current_command()
87
+ # The entrypoint on the client side.
88
+ self.client_entrypoint: Optional[str] = None
89
+ # The entrypoint on the server side, where each request has a entrypoint
90
+ # and a single client_entrypoint can have multiple server-side
91
+ # entrypoints.
79
92
  self.entrypoint: Optional[str] = None # entrypoint_context
80
93
  #: Whether entrypoint is called by sky internal code.
81
94
  self.internal: bool = False # set_internal
@@ -140,6 +153,7 @@ class UsageMessageToReport(MessageToReport):
140
153
  #: Requested number of nodes
141
154
  self.task_num_nodes: Optional[int] = None # update_actual_task
142
155
  # YAMLs converted to JSON.
156
+ # TODO: include the skypilot config used in task yaml.
143
157
  self.user_task_yaml: Optional[List[Dict[
144
158
  str, Any]]] = None # update_user_task_yaml
145
159
  self.actual_task: Optional[List[Dict[str,
@@ -151,11 +165,14 @@ class UsageMessageToReport(MessageToReport):
151
165
  self.exception: Optional[str] = None # entrypoint_context
152
166
  self.stacktrace: Optional[str] = None # entrypoint_context
153
167
 
154
- def __repr__(self) -> str:
155
- d = self.get_properties()
156
- return json.dumps(d)
168
+ # Whether API server is deployed remotely.
169
+ self.using_remote_api_server: bool = (
170
+ common_utils.get_using_remote_api_server())
157
171
 
158
172
  def update_entrypoint(self, msg: str):
173
+ if self.client_entrypoint is None:
174
+ self.client_entrypoint = common_utils.get_current_client_entrypoint(
175
+ msg)
159
176
  self.entrypoint = msg
160
177
 
161
178
  def set_internal(self):
@@ -200,9 +217,11 @@ class UsageMessageToReport(MessageToReport):
200
217
  def update_ray_yaml(self, yaml_config_or_path: Union[Dict, str]):
201
218
  if self.ray_yamls is None:
202
219
  self.ray_yamls = []
203
- self.ray_yamls.extend(
204
- prepare_json_from_yaml_config(yaml_config_or_path))
205
- self.num_tried_regions = len(self.ray_yamls)
220
+ if self.num_tried_regions is None:
221
+ self.num_tried_regions = 0
222
+ # Only keep the latest ray yaml to reduce the size of the message.
223
+ self.ray_yamls = prepare_json_from_yaml_config(yaml_config_or_path)
224
+ self.num_tried_regions += 1
206
225
 
207
226
  def update_cluster_name(self, cluster_name: Union[List[str], str]):
208
227
  if isinstance(cluster_name, str):
@@ -266,16 +285,43 @@ class UsageMessageToReport(MessageToReport):
266
285
  name_or_fn)
267
286
 
268
287
 
288
+ class HeartbeatMessageToReport(MessageToReport):
289
+ """Message to be reported to Grafana Loki for heartbeat on a cluster."""
290
+
291
+ def __init__(self, interval_seconds: int = 600):
292
+ super().__init__(constants.USAGE_MESSAGE_SCHEMA_VERSION)
293
+ # This interval_seconds is mainly for recording the heartbeat interval
294
+ # in the heartbeat message, so that the collector can use it.
295
+ self.interval_seconds = interval_seconds
296
+
297
+ def get_properties(self) -> Dict[str, Any]:
298
+ properties = super().get_properties()
299
+ # The run id is set by the skylet, which will always be the same for
300
+ # the entire lifetime of the run.
301
+ with open(os.path.expanduser(constants.USAGE_RUN_ID_FILE),
302
+ 'r',
303
+ encoding='utf-8') as f:
304
+ properties['run_id'] = f.read().strip()
305
+ return properties
306
+
307
+
269
308
  class MessageCollection:
270
309
  """A collection of messages."""
271
310
 
272
311
  def __init__(self):
273
- self._messages = {MessageType.USAGE: UsageMessageToReport()}
312
+ self._messages = {
313
+ MessageType.USAGE: UsageMessageToReport(),
314
+ MessageType.HEARTBEAT: HeartbeatMessageToReport()
315
+ }
274
316
 
275
317
  @property
276
- def usage(self):
318
+ def usage(self) -> UsageMessageToReport:
277
319
  return self._messages[MessageType.USAGE]
278
320
 
321
+ @property
322
+ def heartbeat(self) -> HeartbeatMessageToReport:
323
+ return self._messages[MessageType.HEARTBEAT]
324
+
279
325
  def reset(self, message_type: MessageType):
280
326
  self._messages[message_type] = self._messages[message_type].__class__()
281
327
 
@@ -299,13 +345,25 @@ def _send_to_loki(message_type: MessageType):
299
345
 
300
346
  message = messages[message_type]
301
347
 
348
+ # In case the message has no start time, set it to the current time.
349
+ message.start()
302
350
  message.send_time = _get_current_timestamp_ns()
303
- log_timestamp = message.start_time
351
+ # Use send time instead of start time to avoid the message being dropped
352
+ # by Loki, due to the timestamp being too old. We still have the start time
353
+ # in the message for dashboard.
354
+ log_timestamp = message.send_time
304
355
 
305
356
  environment = 'prod'
306
357
  if env_options.Options.IS_DEVELOPER.get():
307
358
  environment = 'dev'
308
- prom_labels = {'type': message_type.value, 'environment': environment}
359
+ prom_labels = {
360
+ 'type': message_type.value,
361
+ 'environment': environment,
362
+ 'schema_version': message.schema_version,
363
+ }
364
+ if message_type == MessageType.USAGE:
365
+ prom_labels['new_cluster'] = (message.original_cluster_status != 'UP'
366
+ and message.final_cluster_status == 'UP')
309
367
 
310
368
  headers = {'Content-type': 'application/json'}
311
369
  payload = {
@@ -383,7 +441,7 @@ def prepare_json_from_yaml_config(
383
441
  def _send_local_messages():
384
442
  """Send all messages not been uploaded to Loki."""
385
443
  for msg_type, message in messages.items():
386
- if not message.message_sent:
444
+ if not message.message_sent and msg_type != MessageType.HEARTBEAT:
387
445
  # Avoid the fallback entrypoint to send the message again
388
446
  # in normal case.
389
447
  try:
@@ -393,17 +451,26 @@ def _send_local_messages():
393
451
  f'exception caught: {type(e)}({e})')
394
452
 
395
453
 
396
- @contextlib.contextmanager
397
- def entrypoint_context(name: str, fallback: bool = False):
398
- """Context manager for entrypoint.
454
+ def store_exception(e: Union[Exception, SystemExit, KeyboardInterrupt]) -> None:
455
+ with ux_utils.enable_traceback():
456
+ if hasattr(e, 'stacktrace') and e.stacktrace is not None:
457
+ messages.usage.stacktrace = e.stacktrace
458
+ else:
459
+ trace = traceback.format_exc()
460
+ messages.usage.stacktrace = trace
461
+ if hasattr(e, 'detailed_reason') and e.detailed_reason is not None:
462
+ messages.usage.stacktrace += '\nDetails: ' + e.detailed_reason
463
+ messages.usage.exception = common_utils.remove_color(
464
+ common_utils.format_exception(e))
399
465
 
400
- The context manager will send the usage message to Loki when exiting.
401
- The message will only be sent at the outermost level of the context.
402
466
 
403
- When the outermost context does not cover all the codepaths, an
404
- additional entrypoint_context with fallback=True can be used to wrap
405
- the global entrypoint to catch any exceptions that are not caught.
406
- """
467
+ def send_heartbeat(interval_seconds: int = 600):
468
+ messages.heartbeat.interval_seconds = interval_seconds
469
+ _send_to_loki(MessageType.HEARTBEAT)
470
+
471
+
472
+ def maybe_show_privacy_policy():
473
+ """Show the privacy policy if it is not already shown."""
407
474
  # Show the policy message only when the entrypoint is used.
408
475
  # An indicator for PRIVACY_POLICY has already been shown.
409
476
  privacy_policy_indicator = os.path.expanduser(constants.PRIVACY_POLICY_PATH)
@@ -411,10 +478,22 @@ def entrypoint_context(name: str, fallback: bool = False):
411
478
  os.makedirs(os.path.dirname(privacy_policy_indicator), exist_ok=True)
412
479
  try:
413
480
  with open(privacy_policy_indicator, 'x', encoding='utf-8'):
414
- click.secho(constants.USAGE_POLICY_MESSAGE, fg='yellow')
481
+ logger.info(constants.USAGE_POLICY_MESSAGE)
415
482
  except FileExistsError:
416
483
  pass
417
484
 
485
+
486
+ @contextlib.contextmanager
487
+ def entrypoint_context(name: str, fallback: bool = False):
488
+ """Context manager for entrypoint.
489
+
490
+ The context manager will send the usage message to Loki when exiting.
491
+ The message will only be sent at the outermost level of the context.
492
+
493
+ When the outermost context does not cover all the codepaths, an
494
+ additional entrypoint_context with fallback=True can be used to wrap
495
+ the global entrypoint to catch any exceptions that are not caught.
496
+ """
418
497
  is_entry = messages.usage.entrypoint is None
419
498
  if is_entry and not fallback:
420
499
  for message in messages.values():
@@ -428,13 +507,7 @@ def entrypoint_context(name: str, fallback: bool = False):
428
507
  try:
429
508
  yield
430
509
  except (Exception, SystemExit, KeyboardInterrupt) as e:
431
- with ux_utils.enable_traceback():
432
- trace = traceback.format_exc()
433
- messages.usage.stacktrace = trace
434
- if hasattr(e, 'detailed_reason') and e.detailed_reason is not None:
435
- messages.usage.stacktrace += '\nDetails: ' + e.detailed_reason
436
- messages.usage.exception = common_utils.remove_color(
437
- common_utils.format_exception(e))
510
+ store_exception(e)
438
511
  raise
439
512
  finally:
440
513
  if fallback:
@@ -442,7 +515,27 @@ def entrypoint_context(name: str, fallback: bool = False):
442
515
  _send_local_messages()
443
516
 
444
517
 
445
- def entrypoint(name_or_fn: Union[str, Callable], fallback: bool = False):
518
+ T = typing.TypeVar('T')
519
+
520
+
521
+ @typing.overload
522
+ def entrypoint(
523
+ name_or_fn: str,
524
+ fallback: bool = False
525
+ ) -> Callable[[Callable[..., T]], Callable[..., T]]:
526
+ ...
527
+
528
+
529
+ @typing.overload
530
+ def entrypoint(name_or_fn: Callable[..., T],
531
+ fallback: bool = False) -> Callable[..., T]:
532
+ ...
533
+
534
+
535
+ def entrypoint(
536
+ name_or_fn: Union[str, Callable[..., T]],
537
+ fallback: bool = False
538
+ ) -> Union[Callable[..., T], Callable[[Callable[..., T]], Callable[..., T]]]:
446
539
  return common_utils.make_decorator(entrypoint_context,
447
540
  name_or_fn,
448
541
  fallback=fallback)
@@ -3,12 +3,13 @@ import typing
3
3
  from typing import Optional
4
4
 
5
5
  from sky.clouds import service_catalog
6
+ from sky.utils import rich_utils
6
7
  from sky.utils import ux_utils
7
8
 
8
9
  if typing.TYPE_CHECKING:
9
10
  from sky import clouds
10
11
 
11
- # Canonicalized names of all accelerators (except TPUs) supported by SkyPilot.
12
+ # Canonical names of all accelerators (except TPUs) supported by SkyPilot.
12
13
  # NOTE: Must include accelerators supported for local clusters.
13
14
  #
14
15
  # 1. What if a name is in this list, but not in any catalog?
@@ -30,30 +31,10 @@ if typing.TYPE_CHECKING:
30
31
  #
31
32
  # Append its case-sensitive canonical name to this list. The name must match
32
33
  # `AcceleratorName` in the service catalog.
33
- _ACCELERATORS = [
34
- 'A100',
35
- 'A10G',
36
- 'Gaudi HL-205',
37
- 'Inferentia',
38
- 'Trainium',
39
- 'K520',
40
- 'K80',
41
- 'M60',
42
- 'Radeon Pro V520',
43
- 'T4',
44
- 'T4g',
45
- 'V100',
46
- 'V100-32GB',
47
- 'Virtex UltraScale (VU9P)',
48
- 'A10',
49
- 'A100-80GB',
50
- 'P100',
51
- 'P40',
52
- 'Radeon MI25',
53
- 'P4',
54
- 'L4',
55
- 'H100',
56
- ]
34
+
35
+ # Use a cached version of accelerators to cloud mapping, so that we don't have
36
+ # to download and read the catalog file for every cloud locally.
37
+ _accelerator_df = service_catalog.common.read_catalog('common/accelerators.csv')
57
38
 
58
39
  # List of non-GPU accelerators that are supported by our backend for job queue
59
40
  # scheduling.
@@ -77,42 +58,45 @@ def canonicalize_accelerator_name(accelerator: str,
77
58
  """Returns the canonical accelerator name."""
78
59
  cloud_str = None
79
60
  if cloud is not None:
80
- cloud_str = str(cloud).lower()
61
+ cloud_str = str(cloud)
81
62
 
82
63
  # TPU names are always lowercase.
83
64
  if accelerator.lower().startswith('tpu-'):
84
65
  return accelerator.lower()
85
66
 
86
67
  # Common case: do not read the catalog files.
87
- mapping = {name.lower(): name for name in _ACCELERATORS}
88
- if accelerator.lower() in mapping:
89
- return mapping[accelerator.lower()]
90
-
91
- # _ACCELERATORS may not be comprehensive.
92
- # Users may manually add new accelerators to the catalogs, or download new
93
- # catalogs (that have new accelerators) without upgrading SkyPilot.
94
- # To cover such cases, we should search the accelerator name
95
- # in the service catalog.
96
- searched = service_catalog.list_accelerators(name_filter=accelerator,
97
- case_sensitive=False,
98
- clouds=cloud_str)
99
- names = list(searched.keys())
100
-
101
- # Exact match.
102
- if accelerator in names:
68
+ df = _accelerator_df[_accelerator_df['AcceleratorName'].str.contains(
69
+ accelerator, case=False, regex=True)]
70
+ names = []
71
+ for name, clouds in df[['AcceleratorName', 'Clouds']].values:
72
+ if accelerator.lower() == name.lower():
73
+ return name
74
+ if cloud_str is None or cloud_str in clouds:
75
+ names.append(name)
76
+
77
+ # Look for Kubernetes accelerators online if the accelerator is not found
78
+ # in the public cloud catalog. This is to make sure custom accelerators
79
+ # on Kubernetes can be correctly canonicalized.
80
+ if not names and cloud_str in ['kubernetes', None]:
81
+ with rich_utils.safe_status(
82
+ ux_utils.spinner_message('Listing accelerators on Kubernetes')):
83
+ searched = service_catalog.list_accelerators(
84
+ name_filter=accelerator,
85
+ case_sensitive=False,
86
+ clouds=cloud_str,
87
+ )
88
+ names = list(searched.keys())
89
+ if accelerator in names:
90
+ return accelerator
91
+
92
+ if not names:
93
+ # If no match is found, it is fine to return the original name, as
94
+ # the custom accelerator might be on kubernetes cluster.
103
95
  return accelerator
104
96
 
105
97
  if len(names) == 1:
106
98
  return names[0]
107
-
108
- # Do not print an error message here. Optimizer will handle it.
109
- if len(names) == 0:
110
- return accelerator
111
-
112
- # Currently unreachable.
113
- # This can happen if catalogs have the same accelerator with
114
- # different names (e.g., A10g and A10G).
115
- assert len(names) > 1
99
+ assert len(names) > 1, names
116
100
  with ux_utils.print_exception_no_traceback():
117
101
  raise ValueError(f'Accelerator name {accelerator!r} is ambiguous. '
118
102
  f'Please choose one of {names}.')
@@ -0,0 +1,147 @@
1
+ """Admin policy utils."""
2
+ import copy
3
+ import importlib
4
+ import os
5
+ import tempfile
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import colorama
9
+
10
+ from sky import admin_policy
11
+ from sky import dag as dag_lib
12
+ from sky import exceptions
13
+ from sky import sky_logging
14
+ from sky import skypilot_config
15
+ from sky import task as task_lib
16
+ from sky.utils import common_utils
17
+ from sky.utils import config_utils
18
+ from sky.utils import ux_utils
19
+
20
+ logger = sky_logging.init_logger(__name__)
21
+
22
+
23
+ def _get_policy_cls(
24
+ policy: Optional[str]) -> Optional[admin_policy.AdminPolicy]:
25
+ """Gets admin-defined policy."""
26
+ if policy is None:
27
+ return None
28
+ try:
29
+ module_path, class_name = policy.rsplit('.', 1)
30
+ module = importlib.import_module(module_path)
31
+ except ImportError as e:
32
+ with ux_utils.print_exception_no_traceback():
33
+ raise ImportError(
34
+ f'Failed to import policy module: {policy}. '
35
+ 'Please check if the module is installed in your Python '
36
+ 'environment.') from e
37
+
38
+ try:
39
+ policy_cls = getattr(module, class_name)
40
+ except AttributeError as e:
41
+ with ux_utils.print_exception_no_traceback():
42
+ raise AttributeError(
43
+ f'Could not find {class_name} class in module {module_path}. '
44
+ 'Please check with your policy admin for details.') from e
45
+
46
+ # Check if the module implements the AdminPolicy interface.
47
+ if not issubclass(policy_cls, admin_policy.AdminPolicy):
48
+ with ux_utils.print_exception_no_traceback():
49
+ raise ValueError(
50
+ f'Policy class {policy!r} does not implement the AdminPolicy '
51
+ 'interface. Please check with your policy admin for details.')
52
+ return policy_cls
53
+
54
+
55
def apply(
    entrypoint: Union['dag_lib.Dag', 'task_lib.Task'],
    use_mutated_config_in_current_request: bool = True,
    request_options: Optional[admin_policy.RequestOptions] = None,
) -> Tuple['dag_lib.Dag', config_utils.Config]:
    """Applies an admin policy (if registered) to a DAG or a task.

    It mutates a Dag by applying any registered admin policy and also
    potentially updates (controlled by `use_mutated_config_in_current_request`)
    the global SkyPilot config if there is any changes made by the policy.

    Args:
        entrypoint: The dag to be mutated by the policy. A single task is
            wrapped into a new one-task dag first.
        use_mutated_config_in_current_request: Whether to use the mutated
            config in the current request. When True and the policy changed
            the config, the mutated config is written to a temporary YAML
            file, exported via the SkyPilot config environment variable, and
            the `skypilot_config` module is reloaded.
        request_options: Additional options user passed for the current request.

    Returns:
        - The new copy of dag after applying the policy (or the input dag,
          unchanged, if no policy is registered).
        - The new copy of skypilot config after applying the policy.

    Raises:
        exceptions.UserRequestRejectedByPolicy: If the policy raises while
            validating/mutating any task, or if different tasks of the dag
            end up with different SkyPilot configs after the policy.
    """
    if isinstance(entrypoint, task_lib.Task):
        dag = dag_lib.Dag()
        dag.add(entrypoint)
    else:
        dag = entrypoint

    policy = skypilot_config.get_nested(('admin_policy',), None)
    policy_cls = _get_policy_cls(policy)
    if policy_cls is None:
        return dag, skypilot_config.to_dict()

    logger.info(f'Applying policy: {policy}')
    original_config = skypilot_config.to_dict()
    mutated_dag = dag_lib.Dag()
    mutated_dag.name = dag.name

    mutated_config = None
    for task in dag.tasks:
        # Give each task its own copy of the original config. With a single
        # shared copy, a policy that mutates (and returns) the config it was
        # given would leak that mutation into the next task's input, and the
        # consistency check below would then compare an object to itself.
        config = copy.deepcopy(original_config)
        user_request = admin_policy.UserRequest(task, config, request_options)
        try:
            mutated_user_request = policy_cls.validate_and_mutate(user_request)
        except Exception as e:  # pylint: disable=broad-except
            with ux_utils.print_exception_no_traceback():
                raise exceptions.UserRequestRejectedByPolicy(
                    f'{colorama.Fore.RED}User request rejected by policy '
                    f'{policy!r}{colorama.Fore.RESET}: '
                    f'{common_utils.format_exception(e, use_bracket=True)}'
                ) from e
        logger.debug(f'Mutated user request: {mutated_user_request}')
        if mutated_config is None:
            mutated_config = mutated_user_request.skypilot_config
        elif mutated_config != mutated_user_request.skypilot_config:
            # In the case of a pipeline of tasks, the mutated config
            # generated should remain the same for all tasks for now for
            # simplicity.
            # TODO(zhwu): We should support per-task mutated config or
            # allowing overriding required global config in task YAML.
            with ux_utils.print_exception_no_traceback():
                raise exceptions.UserRequestRejectedByPolicy(
                    'All tasks must have the same SkyPilot config after '
                    'applying the policy. Please check with your policy '
                    'admin for details.')
        mutated_dag.add(mutated_user_request.task)
    assert mutated_config is not None, dag

    # Update the new_dag's graph with the old dag's graph. Tasks were added
    # to mutated_dag in the same order as dag.tasks, so indices line up.
    for u, v in dag.graph.edges:
        u_idx = dag.tasks.index(u)
        v_idx = dag.tasks.index(v)
        mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx],
                                   mutated_dag.tasks[v_idx])

    if (use_mutated_config_in_current_request and
            original_config != mutated_config):
        with tempfile.NamedTemporaryFile(
                delete=False,
                mode='w',
                prefix='policy-mutated-skypilot-config-',
                suffix='.yaml') as temp_file:

            common_utils.dump_yaml(temp_file.name, dict(mutated_config))
            os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name
            logger.debug(f'Updated SkyPilot config: {temp_file.name}')
            # TODO(zhwu): This is not a clean way to update the SkyPilot config,
            # because we are resetting the global context for a single DAG,
            # which is conceptually weird.
            importlib.reload(skypilot_config)

    mutated_dag.policy_applied = True
    return mutated_dag, mutated_config
@@ -0,0 +1,51 @@
1
+ """Annotations for public APIs."""
2
+
3
+ import functools
4
+ from typing import Callable, Literal
5
+
6
+ # Whether the current process is a SkyPilot API server process.
7
+ is_on_api_server = True
8
+ FUNCTIONS_NEED_RELOAD_CACHE = []
9
+
10
+
11
def client_api(func):
    """Decorate ``func`` as a client-side API entry point.

    Each call to the returned wrapper first sets the module-level flag
    ``is_on_api_server`` to False and then delegates to ``func`` with the
    same arguments. Code invoked by server-side functions will therefore find
    annotations.is_on_api_server to be True, so it can apply server-side
    handling.
    """

    @functools.wraps(func)
    def _client_side_call(*call_args, **call_kwargs):
        global is_on_api_server
        is_on_api_server = False
        result = func(*call_args, **call_kwargs)
        return result

    return _client_side_call
25
+
26
+
27
def lru_cache(scope: Literal['global', 'request'], *lru_cache_args,
              **lru_cache_kwargs) -> Callable:
    """LRU cache decorator for functions.

    This decorator allows us to track which functions need to be reloaded for a
    new request using the scope argument.

    Args:
        scope: Whether the cache is global or request-specific, i.e. needs to be
            reloaded for a new request.
        lru_cache_args: Arguments for functools.lru_cache.
        lru_cache_kwargs: Keyword arguments for functools.lru_cache.

    Returns:
        A decorator that wraps a function with functools.lru_cache. For the
        'request' scope, the cached function is also appended to
        FUNCTIONS_NEED_RELOAD_CACHE.

    Raises:
        ValueError: If ``scope`` is neither 'global' nor 'request'.
    """
    # Validate eagerly: previously any value other than 'global' (e.g. a typo)
    # was silently treated as request-scoped.
    if scope not in ('global', 'request'):
        raise ValueError(f'Invalid lru_cache scope {scope!r}. Expected '
                         "'global' or 'request'.")

    def decorator(func: Callable) -> Callable:
        cached_func = functools.lru_cache(*lru_cache_args,
                                          **lru_cache_kwargs)(func)
        if scope == 'request':
            # Track request-scoped caches so they can be reloaded for a new
            # request.
            FUNCTIONS_NEED_RELOAD_CACHE.append(cached_func)
        return cached_func

    return decorator