skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/provision/gcp/mig_utils.py (new file)
@@ -0,0 +1,210 @@
+"""Managed Instance Group Utils"""
+import re
+import subprocess
+from typing import Any, Dict
+
+from sky import sky_logging
+from sky.adaptors import gcp
+from sky.provision.gcp import constants
+
+logger = sky_logging.init_logger(__name__)
+
+MIG_RESOURCE_NOT_FOUND_PATTERN = re.compile(
+    r'The resource \'projects/.*/zones/.*/instanceGroupManagers/.*\' was not '
+    r'found')
+
+IT_RESOURCE_NOT_FOUND_PATTERN = re.compile(
+    r'The resource \'projects/.*/regions/.*/instanceTemplates/.*\' was not '
+    'found')
+
+
+def get_instance_template_name(cluster_name: str) -> str:
+    return f'{constants.INSTANCE_TEMPLATE_NAME_PREFIX}{cluster_name}'
+
+
+def get_managed_instance_group_name(cluster_name: str) -> str:
+    return f'{constants.MIG_NAME_PREFIX}{cluster_name}'
+
+
+def check_instance_template_exits(project_id: str, region: str,
+                                  template_name: str) -> bool:
+    compute = gcp.build('compute',
+                        'v1',
+                        credentials=None,
+                        cache_discovery=False)
+    try:
+        compute.regionInstanceTemplates().get(
+            project=project_id, region=region,
+            instanceTemplate=template_name).execute()
+    except gcp.http_error_exception() as e:
+        if IT_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None:
+            # Instance template does not exist.
+            return False
+        raise
+    return True
+
+
+def create_region_instance_template(cluster_name_on_cloud: str, project_id: str,
+                                    region: str, template_name: str,
+                                    node_config: Dict[str, Any]) -> dict:
+    """Create a regional instance template."""
+    logger.debug(f'Creating regional instance template {template_name!r}.')
+    compute = gcp.build('compute',
+                        'v1',
+                        credentials=None,
+                        cache_discovery=False)
+    config = node_config.copy()
+    config.pop(constants.MANAGED_INSTANCE_GROUP_CONFIG, None)
+
+    # We have to ignore user defined scheduling for DWS.
+    # TODO: Add a warning log for this behaviour.
+    scheduling = config.get('scheduling', {})
+    assert scheduling.get('provisioningModel') != 'SPOT', (
+        'DWS does not support spot VMs.')
+
+    reservations_affinity = config.pop('reservation_affinity', None)
+    if reservations_affinity is not None:
+        logger.warning(
+            f'Ignoring reservations_affinity {reservations_affinity} '
+            'for DWS.')
+
+    # Create the regional instance template request
+    operation = compute.regionInstanceTemplates().insert(
+        project=project_id,
+        region=region,
+        body={
+            'name': template_name,
+            'properties': dict(
+                description=(
+                    'SkyPilot instance template for '
+                    f'{cluster_name_on_cloud!r} to support DWS requests.'),
+                reservationAffinity=dict(
+                    consumeReservationType='NO_RESERVATION'),
+                **config,
+            )
+        }).execute()
+    return operation
+
+
+def create_managed_instance_group(project_id: str, zone: str, group_name: str,
+                                  instance_template_url: str,
+                                  size: int) -> dict:
+    logger.debug(f'Creating managed instance group {group_name!r}.')
+    compute = gcp.build('compute',
+                        'v1',
+                        credentials=None,
+                        cache_discovery=False)
+    operation = compute.instanceGroupManagers().insert(
+        project=project_id,
+        zone=zone,
+        body={
+            'name': group_name,
+            'instanceTemplate': instance_template_url,
+            'target_size': size,
+            'instanceLifecyclePolicy': {
+                'defaultActionOnFailure': 'DO_NOTHING',
+            },
+            'updatePolicy': {
+                'type': 'OPPORTUNISTIC',
+            },
+        }).execute()
+    return operation
+
+
+def resize_managed_instance_group(project_id: str, zone: str, group_name: str,
+                                  resize_by: int, run_duration: int) -> dict:
+    logger.debug(f'Resizing managed instance group {group_name!r} by '
+                 f'{resize_by} with run duration {run_duration}.')
+    compute = gcp.build('compute',
+                        'beta',
+                        credentials=None,
+                        cache_discovery=False)
+    operation = compute.instanceGroupManagerResizeRequests().insert(
+        project=project_id,
+        zone=zone,
+        instanceGroupManager=group_name,
+        body={
+            'name': group_name,
+            'resizeBy': resize_by,
+            'requestedRunDuration': {
+                'seconds': run_duration,
+            }
+        }).execute()
+    return operation
+
+
+def cancel_all_resize_request_for_mig(project_id: str, zone: str,
+                                      group_name: str) -> None:
+    logger.debug(f'Cancelling all resize requests for MIG {group_name!r}.')
+    try:
+        compute = gcp.build('compute',
+                            'beta',
+                            credentials=None,
+                            cache_discovery=False)
+        operation = compute.instanceGroupManagerResizeRequests().list(
+            project=project_id,
+            zone=zone,
+            instanceGroupManager=group_name,
+            filter='state eq ACCEPTED').execute()
+        for request in operation.get('items', []):
+            try:
+                compute.instanceGroupManagerResizeRequests().cancel(
+                    project=project_id,
+                    zone=zone,
+                    instanceGroupManager=group_name,
+                    resizeRequest=request['name']).execute()
+            except gcp.http_error_exception() as e:
+                logger.warning('Failed to cancel resize request '
+                               f'{request["id"]!r}: {e}')
+    except gcp.http_error_exception() as e:
+        if re.search(MIG_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None:
+            raise
+        logger.warning(f'MIG {group_name!r} does not exist. Skip '
+                       'resize request cancellation.')
+        logger.debug(f'Error: {e}')
+
+
+def check_managed_instance_group_exists(project_id: str, zone: str,
+                                        group_name: str) -> bool:
+    compute = gcp.build('compute',
+                        'v1',
+                        credentials=None,
+                        cache_discovery=False)
+    try:
+        compute.instanceGroupManagers().get(
+            project=project_id, zone=zone,
+            instanceGroupManager=group_name).execute()
+    except gcp.http_error_exception() as e:
+        if MIG_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None:
+            return False
+        raise
+    return True
+
+
+def wait_for_managed_group_to_be_stable(project_id: str, zone: str,
+                                        group_name: str, timeout: int) -> None:
+    """Wait until the managed instance group is stable."""
+    logger.debug(f'Waiting for MIG {group_name} to be stable with timeout '
+                 f'{timeout}.')
+    try:
+        cmd = ('gcloud compute instance-groups managed wait-until '
+               f'{group_name} '
+               '--stable '
+               f'--zone={zone} '
+               f'--project={project_id} '
+               f'--timeout={timeout}')
+        logger.info(
+            f'Waiting for MIG {group_name} to be stable with command:\n{cmd}')
+        proc = subprocess.run(
+            f'yes | {cmd}',
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            check=True,
+        )
+        stdout = proc.stdout.decode('ascii')
+        logger.info(stdout)
+    except subprocess.CalledProcessError as e:
+        stderr = e.stderr.decode('ascii')
+        logger.info(stderr)
+        raise
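Taken together, the helpers in this new module cover the template -> group -> resize -> wait lifecycle for GCP Dynamic Workload Scheduler (DWS) requests. Below is a minimal sketch of how they could compose, assuming GCP credentials are already configured; the project, region, zone, cluster name, and empty node_config are hypothetical placeholders, not values from this release:

from sky.provision.gcp import mig_utils

project_id = 'my-project'  # hypothetical
region = 'us-central1'     # hypothetical
zone = 'us-central1-a'     # hypothetical
cluster = 'my-cluster'     # hypothetical

template = mig_utils.get_instance_template_name(cluster)
group = mig_utils.get_managed_instance_group_name(cluster)

if not mig_utils.check_instance_template_exits(project_id, region, template):
    # node_config normally comes from the provisioner; an empty dict just
    # keeps the sketch self-contained.
    mig_utils.create_region_instance_template(cluster, project_id, region,
                                              template, node_config={})

if not mig_utils.check_managed_instance_group_exists(project_id, zone, group):
    template_url = (f'projects/{project_id}/regions/{region}/'
                    f'instanceTemplates/{template}')
    mig_utils.create_managed_instance_group(project_id, zone, group,
                                            template_url, size=0)

# Ask DWS for one more instance for an hour, then wait for stability.
mig_utils.resize_managed_instance_group(project_id, zone, group,
                                        resize_by=1, run_duration=3600)
mig_utils.wait_for_managed_group_to_be_stable(project_id, zone, group,
                                              timeout=1200)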
sky/provision/instance_setup.py
@@ -4,10 +4,10 @@ import functools
 import hashlib
 import json
 import os
-import resource
 import time
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
+from sky import exceptions
 from sky import provision
 from sky import sky_logging
 from sky.provision import common
@@ -15,15 +15,17 @@ from sky.provision import docker_utils
 from sky.provision import logging as provision_logging
 from sky.provision import metadata_utils
 from sky.skylet import constants
+from sky.usage import constants as usage_constants
+from sky.usage import usage_lib
 from sky.utils import accelerator_registry
 from sky.utils import command_runner
 from sky.utils import common_utils
+from sky.utils import env_options
 from sky.utils import subprocess_utils
+from sky.utils import timeline
 from sky.utils import ux_utils
 
 logger = sky_logging.init_logger(__name__)
-_START_TITLE = '\n' + '-' * 20 + 'Start: {} ' + '-' * 20
-_END_TITLE = '-' * 20 + 'End: {} ' + '-' * 20 + '\n'
 
 _MAX_RETRY = 6
 
@@ -44,8 +46,8 @@ _RAY_PORT_COMMAND = (
     f'RAY_PORT=$({constants.SKY_PYTHON_CMD} -c '
     '"from sky.skylet import job_lib; print(job_lib.get_ray_port())" '
     '2> /dev/null || echo 6379);'
-    f'{constants.SKY_PYTHON_CMD} -c "from sky.utils import common_utils; '
-    'print(common_utils.encode_payload({\'ray_port\': $RAY_PORT}))"')
+    f'{constants.SKY_PYTHON_CMD} -c "from sky.utils import message_utils; '
+    'print(message_utils.encode_payload({\'ray_port\': $RAY_PORT}))"')
 
 # Command that calls `ray status` with SkyPilot's Ray port set.
 RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND = (
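The hunk above moves payload encoding from common_utils to the new sky/utils/message_utils module (file 269 in the list). A quick sketch of the round trip the RAY_PORT command relies on, assuming message_utils keeps the same encode/decode contract common_utils had (decode_payload as the matching counterpart is an assumption):

from sky.utils import message_utils

encoded = message_utils.encode_payload({'ray_port': 6379})
# The remote side prints `encoded`; the caller parses it back:
payload = message_utils.decode_payload(encoded)  # assumed counterpart API
assert payload['ray_port'] == 6379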
@@ -68,42 +70,58 @@ MAYBE_SKYLET_RESTART_CMD = (f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV}; '
                             'sky.skylet.attempt_skylet;')
 
 
-def _auto_retry(func):
+def _set_usage_run_id_cmd() -> str:
+    """Gets the command to set the usage run id.
+
+    The command saves the current usage run id to the file, so that the skylet
+    can use it to report the heartbeat.
+
+    We use a function instead of a constant so that the usage run id is the
+    latest one when the function is called.
+    """
+    return (
+        f'cat {usage_constants.USAGE_RUN_ID_FILE} || '
+        # The run id is retrieved locally for the current run, so that the
+        # remote cluster will be set with the same run id as the initial
+        # launch operation.
+        f'echo "{usage_lib.messages.usage.run_id}" > '
+        f'{usage_constants.USAGE_RUN_ID_FILE}')
+
+
+def _set_skypilot_env_var_cmd() -> str:
+    """Sets the skypilot environment variables on the remote machine."""
+    env_vars = env_options.Options.all_options()
+    return '; '.join([f'export {k}={v}' for k, v in env_vars.items()])
+
+
+def _auto_retry(should_retry: Callable[[Exception], bool] = lambda _: True):
     """Decorator that retries the function if it fails.
 
     This decorator is mostly for SSH disconnection issues, which might happen
     during the setup of instances.
     """
 
-    @functools.wraps(func)
-    def retry(*args, **kwargs):
-        backoff = common_utils.Backoff(initial_backoff=1, max_backoff_factor=5)
-        for retry_cnt in range(_MAX_RETRY):
-            try:
-                return func(*args, **kwargs)
-            except Exception as e:  # pylint: disable=broad-except
-                if retry_cnt >= _MAX_RETRY - 1:
-                    raise e
-                sleep = backoff.current_backoff()
-                logger.info(
-                    f'{func.__name__}: Retrying in {sleep:.1f} seconds, '
-                    f'due to {e}')
-                time.sleep(sleep)
-
-    return retry
-
+    def decorator(func):
 
-def _log_start_end(func):
+        @functools.wraps(func)
+        def retry(*args, **kwargs):
+            backoff = common_utils.Backoff(initial_backoff=1,
+                                           max_backoff_factor=5)
+            for retry_cnt in range(_MAX_RETRY):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:  # pylint: disable=broad-except
+                    if not should_retry(e) or retry_cnt >= _MAX_RETRY - 1:
+                        raise
+                    sleep = backoff.current_backoff()
+                    logger.info(
+                        f'{func.__name__}: Retrying in {sleep:.1f} seconds, '
+                        f'due to {e}')
+                    time.sleep(sleep)
 
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        logger.info(_START_TITLE.format(func.__name__))
-        try:
-            return func(*args, **kwargs)
-        finally:
-            logger.info(_END_TITLE.format(func.__name__))
+        return retry
 
-    return wrapper
+    return decorator
 
 
 def _hint_worker_log_path(cluster_name: str, cluster_info: common.ClusterInfo,
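_auto_retry thus changes from a plain decorator into a decorator factory: every use site gains parentheses, and callers may pass a should_retry predicate to retry selectively (initialize_docker below uses one to retry only exceptions.CommandError with SSH exit code 255). A minimal sketch of the two calling conventions, with hypothetical flaky functions for illustration:

@_auto_retry()  # retries any exception, up to _MAX_RETRY attempts
def flaky():  # hypothetical
    raise RuntimeError('transient failure')


@_auto_retry(should_retry=lambda e: isinstance(e, ConnectionError))
def flaky_ssh():  # hypothetical; non-matching errors are raised immediately
    raise ValueError('not retried')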
@@ -124,7 +142,8 @@ def _parallel_ssh_with_cache(func,
     if max_workers is None:
         # Not using the default value of `max_workers` in ThreadPoolExecutor,
         # as 32 is too large for some machines.
-        max_workers = subprocess_utils.get_parallel_threads()
+        max_workers = subprocess_utils.get_parallel_threads(
+            cluster_info.provider_name)
     with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
         results = []
         runners = provision.get_command_runners(cluster_info.provider_name,
@@ -147,7 +166,7 @@
     return [future.result() for future in results]
 
 
-@_log_start_end
+@common.log_function_start_end
 def initialize_docker(cluster_name: str, docker_config: Dict[str, Any],
                       cluster_info: common.ClusterInfo,
                       ssh_credentials: Dict[str, Any]) -> Optional[str]:
@@ -156,7 +175,8 @@ def initialize_docker(cluster_name: str, docker_config: Dict[str, Any],
         return None
     _hint_worker_log_path(cluster_name, cluster_info, 'initialize_docker')
 
-    @_auto_retry
+    @_auto_retry(should_retry=lambda e: isinstance(e, exceptions.CommandError)
+                 and e.returncode == 255)
     def _initialize_docker(runner: command_runner.CommandRunner, log_path: str):
         docker_user = docker_utils.DockerInitializer(docker_config, runner,
                                                      log_path).initialize()
@@ -177,7 +197,8 @@ def initialize_docker(cluster_name: str, docker_config: Dict[str, Any],
     return docker_users[0]
 
 
-@_log_start_end
+@common.log_function_start_end
+@timeline.event
 def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str],
                              cluster_info: common.ClusterInfo,
                              ssh_credentials: Dict[str, Any]) -> None:
@@ -193,7 +214,7 @@ def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str],
         hasher.update(d)
     digest = hasher.hexdigest()
 
-    @_auto_retry
+    @_auto_retry()
     def _setup_node(runner: command_runner.CommandRunner, log_path: str):
         for cmd in setup_commands:
             returncode, stdout, stderr = runner.run(
@@ -253,49 +274,102 @@ def _ray_gpu_options(custom_resource: str) -> str:
     return f' --num-gpus={acc_count}'
 
 
-@_log_start_end
-@_auto_retry
-def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
-                           cluster_info: common.ClusterInfo,
-                           ssh_credentials: Dict[str, Any]) -> None:
-    """Start Ray on the head node."""
-    runners = provision.get_command_runners(cluster_info.provider_name,
-                                            cluster_info, **ssh_credentials)
-    head_runner = runners[0]
-    assert cluster_info.head_instance_id is not None, (cluster_name,
-                                                       cluster_info)
-
-    # Log the head node's output to the provision.log
-    log_path_abs = str(provision_logging.get_log_path())
+def ray_head_start_command(custom_resource: Optional[str],
+                           custom_ray_options: Optional[Dict[str, Any]]) -> str:
+    """Returns the command to start Ray on the head node."""
     ray_options = (
         # --disable-usage-stats in `ray start` saves 10 seconds of idle wait.
         f'--disable-usage-stats '
         f'--port={constants.SKY_REMOTE_RAY_PORT} '
         f'--dashboard-port={constants.SKY_REMOTE_RAY_DASHBOARD_PORT} '
+        f'--min-worker-port 11002 '
         f'--object-manager-port=8076 '
         f'--temp-dir={constants.SKY_REMOTE_RAY_TEMPDIR}')
     if custom_resource:
         ray_options += f' --resources=\'{custom_resource}\''
         ray_options += _ray_gpu_options(custom_resource)
+    if custom_ray_options:
+        if 'use_external_ip' in custom_ray_options:
+            custom_ray_options.pop('use_external_ip')
+        for key, value in custom_ray_options.items():
+            ray_options += f' --{key}={value}'
 
-    if cluster_info.custom_ray_options:
-        if 'use_external_ip' in cluster_info.custom_ray_options:
-            cluster_info.custom_ray_options.pop('use_external_ip')
-        for key, value in cluster_info.custom_ray_options.items():
+    cmd = (
+        f'{constants.SKY_RAY_CMD} stop; '
+        'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 '
+        # worker_maximum_startup_concurrency controls the maximum number of
+        # workers that can be started concurrently. However, it also controls
+        # this warning message:
+        # https://github.com/ray-project/ray/blob/d5d03e6e24ae3cfafb87637ade795fb1480636e6/src/ray/raylet/worker_pool.cc#L1535-L1545
+        # maximum_startup_concurrency defaults to the number of CPUs given by
+        # multiprocessing.cpu_count() or manually specified to ray. (See
+        # https://github.com/ray-project/ray/blob/fab26e1813779eb568acba01281c6dd963c13635/python/ray/_private/services.py#L1622-L1624.)
+        # The warning will show when the number of workers is >4x the
+        # maximum_startup_concurrency, so typically 4x CPU count. However, the
+        # job controller uses 0.25cpu reservations, and each job can use two
+        # workers (one for the submitted job and one for remote actors),
+        # resulting in a worker count of 8x CPUs or more. Increase the
+        # worker_maximum_startup_concurrency to 3x CPUs so that we will only see
+        # the warning when the worker count is >12x CPUs.
+        'RAY_worker_maximum_startup_concurrency=$(( 3 * $(nproc --all) )) '
+        f'{constants.SKY_RAY_CMD} start --head {ray_options} || exit 1;' +
+        _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND)
+    return cmd
+
+
+def ray_worker_start_command(custom_resource: Optional[str],
+                             custom_ray_options: Optional[Dict[str, Any]],
+                             no_restart: bool) -> str:
+    """Returns the command to start Ray on the worker node."""
+    # We need to use the ray port in the env variable, because the head node
+    # determines the port to be used for the worker node.
+    ray_options = ('--address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT} '
+                   '--object-manager-port=8076')
+
+    if custom_resource:
+        ray_options += f' --resources=\'{custom_resource}\''
+        ray_options += _ray_gpu_options(custom_resource)
+
+    if custom_ray_options:
+        for key, value in custom_ray_options.items():
             ray_options += f' --{key}={value}'
 
-    # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY to avoid using credentials
-    # from environment variables set by user. SkyPilot's ray cluster should use
-    # the `~/.aws/` credentials, as that is the one used to create the cluster,
-    # and the autoscaler module started by the `ray start` command should use
-    # the same credentials. Otherwise, `ray status` will fail to fetch the
-    # available nodes.
-    # Reference: https://github.com/skypilot-org/skypilot/issues/2441
-    cmd = (f'{constants.SKY_RAY_CMD} stop; '
-           'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; '
-           'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 '
-           f'{constants.SKY_RAY_CMD} start --head {ray_options} || exit 1;' +
-           _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND)
+    cmd = (
+        'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 '
+        f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || '
+        'exit 1;' + _RAY_PRLIMIT)
+    if no_restart:
+        # We do not use ray status to check whether ray is running, because
+        # on worker node, if the user started their own ray cluster, ray status
+        # will return 0, i.e., we don't know skypilot's ray cluster is running.
+        # Instead, we check whether the raylet process is running on gcs address
+        # that is connected to the head with the correct port.
+        cmd = (
+            f'ps aux | grep "ray/raylet/raylet" | '
+            'grep "gcs-address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT}" '
+            f'|| {{ {cmd} }}')
+    else:
+        cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd
+    return cmd
+
+
+@common.log_function_start_end
+@_auto_retry()
+@timeline.event
+def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
+                           cluster_info: common.ClusterInfo,
+                           ssh_credentials: Dict[str, Any]) -> None:
+    """Start Ray on the head node."""
+    runners = provision.get_command_runners(cluster_info.provider_name,
+                                            cluster_info, **ssh_credentials)
+    head_runner = runners[0]
+    assert cluster_info.head_instance_id is not None, (cluster_name,
+                                                       cluster_info)
+
+    # Log the head node's output to the provision.log
+    log_path_abs = str(provision_logging.get_log_path())
+    cmd = ray_head_start_command(custom_resource,
+                                 cluster_info.custom_ray_options)
     logger.info(f'Running command on head node: {cmd}')
     # TODO(zhwu): add the output to log files.
     returncode, stdout, stderr = head_runner.run(
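This refactor splits command construction out of the start_ray_on_* functions, so the head and worker commands can be built without a ClusterInfo; the worker command now reads the head address from the SKYPILOT_RAY_HEAD_IP and SKYPILOT_RAY_PORT environment variables at run time instead of baking it in. A sketch of how the builders are consumed, mirroring the call sites later in this diff (the IP and port values are hypothetical):

head_cmd = ray_head_start_command(custom_resource=None,
                                  custom_ray_options=None)

worker_cmd = ('export SKYPILOT_RAY_HEAD_IP="10.0.0.1"; '  # hypothetical IP
              'export SKYPILOT_RAY_PORT=6380; ' +         # hypothetical port
              ray_worker_start_command(custom_resource=None,
                                       custom_ray_options=None,
                                       no_restart=True))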
@@ -313,8 +387,9 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
                            f'===== stderr ====={stderr}')
 
 
-@_log_start_end
-@_auto_retry
+@common.log_function_start_end
+@_auto_retry()
+@timeline.event
 def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
                               custom_resource: Optional[str], ray_port: int,
                               cluster_info: common.ClusterInfo,
@@ -349,43 +424,17 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
     head_ip = (head_instance.internal_ip
                if not use_external_ip else head_instance.external_ip)
 
-    ray_options = (f'--address={head_ip}:{constants.SKY_REMOTE_RAY_PORT} '
-                   f'--object-manager-port=8076')
-
-    if custom_resource:
-        ray_options += f' --resources=\'{custom_resource}\''
-        ray_options += _ray_gpu_options(custom_resource)
-
-    if cluster_info.custom_ray_options:
-        for key, value in cluster_info.custom_ray_options.items():
-            ray_options += f' --{key}={value}'
+    ray_cmd = ray_worker_start_command(custom_resource,
+                                       cluster_info.custom_ray_options,
+                                       no_restart)
 
-    # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY, see the comment in
-    # `start_ray_on_head_node`.
-    cmd = (
-        f'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; '
-        'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 '
-        f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || '
-        'exit 1;' + _RAY_PRLIMIT)
-    if no_restart:
-        # We do not use ray status to check whether ray is running, because
-        # on worker node, if the user started their own ray cluster, ray status
-        # will return 0, i.e., we don't know skypilot's ray cluster is running.
-        # Instead, we check whether the raylet process is running on gcs address
-        # that is connected to the head with the correct port.
-        cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | '
-               f'grep "gcs-address={head_ip}:${{RAY_PORT}}" || '
-               f'{{ {cmd} }}')
-    else:
-        cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd
+    cmd = (f'export SKYPILOT_RAY_HEAD_IP="{head_ip}"; '
+           f'export SKYPILOT_RAY_PORT={ray_port}; ' + ray_cmd)
 
     logger.info(f'Running command on worker nodes: {cmd}')
 
     def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner,
                                                str]):
-        # for cmd in config_from_yaml['worker_start_ray_commands']:
-        #     cmd = cmd.replace('$RAY_HEAD_IP', ip_list[0][0])
-        #     runner.run(cmd)
         runner, instance_id = runner_and_id
         log_dir = metadata_utils.get_instance_log_dir(cluster_name, instance_id)
         log_path_abs = str(log_dir / ('ray_cluster' + '.log'))
@@ -398,8 +447,10 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
             # by ray will have the correct PATH.
             source_bashrc=True)
 
+    num_threads = subprocess_utils.get_parallel_threads(
+        cluster_info.provider_name)
     results = subprocess_utils.run_in_parallel(
-        _setup_ray_worker, list(zip(worker_runners, cache_ids)))
+        _setup_ray_worker, list(zip(worker_runners, cache_ids)), num_threads)
     for returncode, stdout, stderr in results:
         if returncode:
             with ux_utils.print_exception_no_traceback():
@@ -410,8 +461,9 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
                     f'===== stderr ====={stderr}')
 
 
-@_log_start_end
-@_auto_retry
+@common.log_function_start_end
+@_auto_retry()
+@timeline.event
 def start_skylet_on_head_node(cluster_name: str,
                               cluster_info: common.ClusterInfo,
                               ssh_credentials: Dict[str, Any]) -> None:
@@ -425,11 +477,17 @@ def start_skylet_on_head_node(cluster_name: str,
     logger.info(f'Running command on head node: {MAYBE_SKYLET_RESTART_CMD}')
     # We need to source bashrc for skylet to make sure the autostop event can
     # access the path to the cloud CLIs.
-    returncode, stdout, stderr = head_runner.run(MAYBE_SKYLET_RESTART_CMD,
-                                                 stream_logs=False,
-                                                 require_outputs=True,
-                                                 log_path=log_path_abs,
-                                                 source_bashrc=True)
+    set_usage_run_id_cmd = _set_usage_run_id_cmd()
+    # Set the skypilot environment variables, including the usage type, debug
+    # info, and other options.
+    set_skypilot_env_var_cmd = _set_skypilot_env_var_cmd()
+    returncode, stdout, stderr = head_runner.run(
+        f'{set_usage_run_id_cmd}; {set_skypilot_env_var_cmd}; '
+        f'{MAYBE_SKYLET_RESTART_CMD}',
+        stream_logs=False,
+        require_outputs=True,
+        log_path=log_path_abs,
+        source_bashrc=True)
     if returncode:
         raise RuntimeError('Failed to start skylet on the head node '
                            f'(exit code {returncode}). Error: '
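The skylet restart command is now prefixed with the usage run id and the SkyPilot environment variables, so the remote skylet inherits the launch-time settings. A sketch of what _set_skypilot_env_var_cmd() evaluates to; the option names in the comment are illustrative assumptions, the actual set comes from env_options.Options.all_options():

from sky.utils import env_options

env_vars = env_options.Options.all_options()
cmd = '; '.join(f'export {k}={v}' for k, v in env_vars.items())
# e.g. 'export SKYPILOT_DEBUG=0; export SKYPILOT_MINIMIZE_LOGGING=1; ...'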
@@ -437,7 +495,7 @@ def start_skylet_on_head_node(cluster_name: str,
                            f'===== stderr ====={stderr}')
 
 
-@_auto_retry
+@_auto_retry()
 def _internal_file_mounts(file_mounts: Dict,
                           runner: command_runner.CommandRunner,
                           log_path: str) -> None:
@@ -473,28 +531,8 @@ def _internal_file_mounts(file_mounts: Dict,
     )
 
 
-def _max_workers_for_file_mounts(common_file_mounts: Dict[str, str]) -> int:
-    fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
-
-    fd_per_rsync = 5
-    for src in common_file_mounts.values():
-        if os.path.isdir(src):
-            # Assume that each file/folder under src takes 5 file descriptors
-            # on average.
-            fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5)
-
-    # Reserve some file descriptors for the system and other processes
-    fd_reserve = 100
-
-    max_workers = (fd_limit - fd_reserve) // fd_per_rsync
-    # At least 1 worker, and avoid too many workers overloading the system.
-    max_workers = min(max(max_workers, 1),
-                      subprocess_utils.get_parallel_threads())
-    logger.debug(f'Using {max_workers} workers for file mounts.')
-    return max_workers
-
-
-@_log_start_end
+@common.log_function_start_end
+@timeline.event
 def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str],
                          cluster_info: common.ClusterInfo,
                          ssh_credentials: Dict[str, str]) -> None:
@@ -515,4 +553,5 @@ def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str],
         digest=None,
         cluster_info=cluster_info,
         ssh_credentials=ssh_credentials,
-        max_workers=_max_workers_for_file_mounts(common_file_mounts))
+        max_workers=subprocess_utils.get_max_workers_for_file_mounts(
+            common_file_mounts, cluster_info.provider_name))
sky/provision/kubernetes/__init__.py
@@ -2,6 +2,7 @@
 
 from sky.provision.kubernetes.config import bootstrap_instances
 from sky.provision.kubernetes.instance import get_cluster_info
+from sky.provision.kubernetes.instance import get_command_runners
 from sky.provision.kubernetes.instance import query_instances
 from sky.provision.kubernetes.instance import run_instances
 from sky.provision.kubernetes.instance import stop_instances