skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/provision/gcp/instance_utils.py
@@ -13,18 +13,15 @@ from sky import sky_logging
 from sky.adaptors import gcp
 from sky.clouds import gcp as gcp_cloud
 from sky.provision import common
+from sky.provision import constants as provision_constants
 from sky.provision.gcp import constants
+from sky.provision.gcp import mig_utils
 from sky.utils import common_utils
 from sky.utils import ux_utils
 
-# Tag uniquely identifying all nodes of a cluster
-TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name'
-TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
 # Tag for the name of the node
 INSTANCE_NAME_MAX_LEN = 64
 INSTANCE_NAME_UUID_LEN = 8
-TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
-TAG_RAY_NODE_KIND = 'ray-node-type'
 
 TPU_NODE_CREATION_FAILURE = 'Failed to provision TPU node.'
 
@@ -41,7 +38,7 @@ _FIREWALL_RESOURCE_NOT_FOUND_PATTERN = re.compile(
     r'The resource \'projects/.*/global/firewalls/.*\' was not found')
 
 
-def _retry_on_http_exception(
+def _retry_on_gcp_http_exception(
     regex: Optional[str] = None,
     max_retries: int = GCP_MAX_RETRIES,
     retry_interval_s: int = GCP_RETRY_INTERVAL_SECONDS,
@@ -52,17 +49,18 @@ def _retry_on_http_exception(
 
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            exception_type = gcp.http_error_exception()
 
             def try_catch_exc():
                 try:
                     value = func(*args, **kwargs)
                     return value
                 except Exception as e:  # pylint: disable=broad-except
-                    if not isinstance(e, exception_type) or (
-                            regex and not re.search(regex, str(e))):
-                        raise
-                    return e
+                    if (isinstance(e, gcp.http_error_exception()) and
+                        (regex is None or re.search(regex, str(e)))):
+                        logger.error(
+                            f'Retrying for gcp.http_error_exception: {e}')
+                        return e
+                    raise
 
             for _ in range(max_retries):
                 ret = try_catch_exc()
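The rename makes the GCP-specific scope explicit, and the inverted condition now retries only on a matching gcp.http_error_exception() while re-raising everything else. A minimal self-contained sketch of the same retry-with-regex-filter pattern (the HttpError stand-in, retry count, and sleep interval here are illustrative, not SkyPilot's actual values):

import functools
import re
import time
from typing import Optional


class HttpError(Exception):
    """Stand-in for googleapiclient.errors.HttpError."""


def retry_on_http_exception(regex: Optional[str] = None,
                            max_retries: int = 3,
                            retry_interval_s: int = 1):
    """Retry the wrapped call while it raises a matching HttpError."""

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for _ in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except HttpError as e:
                    # Retry only when the message matches the given regex
                    # (or when no regex was supplied); others propagate.
                    if regex is not None and not re.search(regex, str(e)):
                        raise
                    time.sleep(retry_interval_s)
            raise RuntimeError(
                f'{func.__name__} did not succeed in {max_retries} tries')

        return wrapper

    return decorator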
@@ -100,19 +98,20 @@ def _generate_node_name(cluster_name: str, node_suffix: str,
     return node_name
 
 
-def _log_errors(errors: List[Dict[str, str]], e: Any,
-                zone: Optional[str]) -> None:
-    """Format errors into a string."""
+def _format_and_log_message_from_errors(errors: List[Dict[str, str]], e: Any,
+                                        zone: Optional[str]) -> str:
+    """Format errors into a string and log it to the console."""
     if errors:
         plural = 's' if len(errors) > 1 else ''
         codes = ', '.join(repr(e.get('code', 'N/A')) for e in errors)
         messages = '; '.join(
             repr(e.get('message', 'N/A').strip('.')) for e in errors)
         zone_str = f' in {zone}' if zone else ''
-        logger.warning(f'Got return code{plural} {codes}'
-                       f'{zone_str}: {messages}')
+        msg = f'Got return code{plural} {codes}{zone_str}: {messages}'
     else:
-        logger.warning(f'create_instances: Failed with reason: {e}')
+        msg = f'create_instances: Failed with reason: {e}'
+    logger.warning(msg)
+    return msg
 
 
 def selflink_to_name(selflink: str) -> str:
@@ -133,6 +132,8 @@ def instance_to_handler(instance: str):
         return GCPComputeInstance
     elif instance_type == 'tpu':
         return GCPTPUVMInstance
+    elif instance.startswith(constants.MIG_NAME_PREFIX):
+        return GCPManagedInstanceGroup
     else:
         raise ValueError(f'Unknown instance type: {instance_type}')
 
@@ -176,8 +177,11 @@ class GCPInstance:
         raise NotImplementedError
 
     @classmethod
-    def wait_for_operation(cls, operation: dict, project_id: str,
-                           zone: Optional[str]) -> None:
+    def wait_for_operation(cls,
+                           operation: dict,
+                           project_id: str,
+                           region: Optional[str] = None,
+                           zone: Optional[str] = None) -> None:
         raise NotImplementedError
 
     @classmethod
@@ -239,6 +243,7 @@ class GCPInstance:
         node_config: dict,
         labels: dict,
         count: int,
+        total_count: int,
        include_head_node: bool,
     ) -> Tuple[Optional[List], List[str]]:
         """Creates multiple instances and returns result.
@@ -247,6 +252,21 @@ class GCPInstance:
         """
         raise NotImplementedError
 
+    @classmethod
+    def start_instances(cls, cluster_name: str, project_id: str, zone: str,
+                        instances: List[str], labels: Dict[str,
+                                                           str]) -> List[str]:
+        """Start multiple instances.
+
+        Returns:
+            List of instance names that are started.
+        """
+        del cluster_name  # Unused
+        for instance_id in instances:
+            cls.start_instance(instance_id, project_id, zone)
+            cls.set_labels(project_id, zone, instance_id, labels)
+        return instances
+
     @classmethod
     def start_instance(cls, node_id: str, project_id: str, zone: str) -> None:
         """Start a stopped instance."""
@@ -264,15 +284,9 @@ class GCPInstance:
                         target_instance_id: str,
                         is_head: bool = True) -> str:
         if is_head:
-            node_tag = {
-                TAG_SKYPILOT_HEAD_NODE: '1',
-                TAG_RAY_NODE_KIND: 'head',
-            }
+            node_tag = provision_constants.HEAD_NODE_TAGS
         else:
-            node_tag = {
-                TAG_SKYPILOT_HEAD_NODE: '0',
-                TAG_RAY_NODE_KIND: 'worker',
-            }
+            node_tag = provision_constants.WORKER_NODE_TAGS
         cls.set_labels(project_id=project_id,
                        availability_zone=availability_zone,
                        node_id=target_instance_id,
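Both branches now read shared tag dicts from the new sky/provision/constants.py (+25 -0 in the file list above). Judging from the inline dicts deleted in this hunk, those constants are presumably along these lines (a sketch, not the module's verbatim contents):

# Tag uniquely identifying all nodes of a cluster.
TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name'
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
TAG_RAY_NODE_KIND = 'ray-node-type'

HEAD_NODE_TAGS = {
    TAG_SKYPILOT_HEAD_NODE: '1',
    TAG_RAY_NODE_KIND: 'head',
}
WORKER_NODE_TAGS = {
    TAG_SKYPILOT_HEAD_NODE: '0',
    TAG_RAY_NODE_KIND: 'worker',
}

Centralizing the tags lets the Compute, TPU, and new MIG handlers agree on how head and worker nodes are labeled.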
@@ -400,18 +414,25 @@ class GCPComputeInstance(GCPInstance):
         return instances
 
     @classmethod
-    def wait_for_operation(cls, operation: dict, project_id: str,
-                           zone: Optional[str]) -> None:
+    def wait_for_operation(cls,
+                           operation: dict,
+                           project_id: str,
+                           region: Optional[str] = None,
+                           zone: Optional[str] = None,
+                           timeout: int = GCP_TIMEOUT) -> None:
         if zone is not None:
             kwargs = {'zone': zone}
             operation_caller = cls.load_resource().zoneOperations()
+        elif region is not None:
+            kwargs = {'region': region}
+            operation_caller = cls.load_resource().regionOperations()
         else:
             kwargs = {}
             operation_caller = cls.load_resource().globalOperations()
         logger.debug(
             f'Waiting GCP operation {operation["name"]} to be ready ...')
 
-        @_retry_on_http_exception(
+        @_retry_on_gcp_http_exception(
             f'Failed to wait for operation {operation["name"]}')
         def call_operation(fn, timeout: int):
             request = fn(
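The new region= parameter exists because MIGs (see GCPManagedInstanceGroup below) use regional instance templates, whose operations must be polled via regionOperations() rather than zoneOperations() or globalOperations(). A standalone sketch of the dispatch, with resource standing in for the googleapiclient discovery client:

from typing import Any, Dict, Optional, Tuple


def pick_operation_caller(
        resource: Any,
        region: Optional[str] = None,
        zone: Optional[str] = None) -> Tuple[Any, Dict[str, str]]:
    """Pick the Compute operations collection matching the operation scope.

    Mirrors the precedence in the diff above: zone, then region, then global.
    """
    if zone is not None:
        return resource.zoneOperations(), {'zone': zone}
    if region is not None:
        return resource.regionOperations(), {'region': region}
    return resource.globalOperations(), {}

Because zone and region are now both optional keywords, the call sites below are updated to pass the scope explicitly (e.g. wait_for_operation(operation, project_id, zone=zone)).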
@@ -423,13 +444,13 @@ class GCPComputeInstance(GCPInstance):
             return request.execute(num_retries=GCP_MAX_RETRIES)
 
         wait_start = time.time()
-        while time.time() - wait_start < GCP_TIMEOUT:
+        while time.time() - wait_start < timeout:
             # Retry the wait() call until it succeeds or times out.
             # This is because the wait() call is only best effort, and does not
             # guarantee that the operation is done when it returns.
             # Reference: https://cloud.google.com/workflows/docs/reference/googleapis/compute/v1/zoneOperations/wait # pylint: disable=line-too-long
-            timeout = max(GCP_TIMEOUT - (time.time() - wait_start), 1)
-            result = call_operation(operation_caller.wait, timeout)
+            remaining_timeout = max(timeout - (time.time() - wait_start), 1)
+            result = call_operation(operation_caller.wait, remaining_timeout)
             if result['status'] == 'DONE':
                 # NOTE: Error example:
                 # {
@@ -441,8 +462,10 @@ class GCPComputeInstance(GCPInstance):
                     logger.debug(
                         'wait_operations: Failed to create instances. Reason: '
                         f'{errors}')
-                    _log_errors(errors, result, zone)
+                    msg = _format_and_log_message_from_errors(
+                        errors, result, zone)
                     error = common.ProvisionerError('Operation failed')
+                    setattr(error, 'detailed_reason', msg)
                     error.errors = errors
                     raise error
                 return
@@ -451,9 +474,10 @@ class GCPComputeInstance(GCPInstance):
         else:
             logger.warning('wait_for_operation: Timeout waiting for creation '
                            'operation, cancelling the operation ...')
-            timeout = max(GCP_TIMEOUT - (time.time() - wait_start), 1)
+            remaining_timeout = max(timeout - (time.time() - wait_start), 1)
             try:
-                result = call_operation(operation_caller.delete, timeout)
+                result = call_operation(operation_caller.delete,
+                                        remaining_timeout)
             except gcp.http_error_exception() as e:
                 logger.debug('wait_for_operation: failed to cancel operation '
                              f'due to error: {e}')
@@ -462,8 +486,10 @@ class GCPComputeInstance(GCPInstance):
                 'message': f'Timeout waiting for operation {operation["name"]}',
                 'domain': 'wait_for_operation'
             }]
-            _log_errors(errors, None, zone)
+            msg = _format_and_log_message_from_errors(errors, None, zone)
            error = common.ProvisionerError('Operation timed out')
+            # Used for usage collection only, to include in the usage message.
+            setattr(error, 'detailed_reason', msg)
             error.errors = errors
             raise error
 
@@ -588,6 +614,11 @@ class GCPComputeInstance(GCPInstance):
         return operation
 
     @classmethod
+    # When there is a cloud function running in parallel to set labels for
+    # newly created instances, it may fail with the following error:
+    #   "Labels fingerprint either invalid or resource labels have changed"
+    # We should retry until the labels are set successfully.
+    @_retry_on_gcp_http_exception('Labels fingerprint either invalid')
     def set_labels(cls, project_id: str, availability_zone: str, node_id: str,
                    labels: dict) -> None:
         node = cls.load_resource().instances().get(
@@ -606,7 +637,7 @@ class GCPComputeInstance(GCPInstance):
             body=body,
         ).execute(num_retries=GCP_CREATE_MAX_RETRIES))
 
-        cls.wait_for_operation(operation, project_id, availability_zone)
+        cls.wait_for_operation(operation, project_id, zone=availability_zone)
 
     @classmethod
     def create_instances(
@@ -617,6 +648,7 @@ class GCPComputeInstance(GCPInstance):
         node_config: dict,
         labels: dict,
         count: int,
+        total_count: int,
         include_head_node: bool,
     ) -> Tuple[Optional[List], List[str]]:
         # NOTE: The syntax for bulkInsert() is different from insert().
@@ -643,8 +675,8 @@ class GCPComputeInstance(GCPInstance):
         config.update({
             'labels': dict(
                 labels, **{
-                    TAG_RAY_CLUSTER_NAME: cluster_name,
-                    TAG_SKYPILOT_CLUSTER_NAME: cluster_name
+                    provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name,
+                    provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name
                 }),
         })
 
@@ -739,6 +771,19 @@ class GCPComputeInstance(GCPInstance):
         logger.debug('"insert" operation requested ...')
         return operations
 
+    @classmethod
+    def _convert_selflinks_in_config(cls, config: dict) -> None:
+        """Convert selflinks to names in the config."""
+        for disk in config.get('disks', []):
+            disk_type = disk.get('initializeParams', {}).get('diskType')
+            if disk_type is not None:
+                disk['initializeParams']['diskType'] = selflink_to_name(
+                    disk_type)
+        config['machineType'] = selflink_to_name(config['machineType'])
+        for accelerator in config.get('guestAccelerators', []):
+            accelerator['acceleratorType'] = selflink_to_name(
+                accelerator['acceleratorType'])
+
     @classmethod
     def _bulk_insert(cls, names: List[str], project_id: str, zone: str,
                      config: dict) -> List[dict]:
@@ -752,15 +797,7 @@ class GCPComputeInstance(GCPInstance):
                 k: v for d in config['scheduling'] for k, v in d.items()
             }
 
-        for disk in config.get('disks', []):
-            disk_type = disk.get('initializeParams', {}).get('diskType')
-            if disk_type is not None:
-                disk['initializeParams']['diskType'] = selflink_to_name(
-                    disk_type)
-        config['machineType'] = selflink_to_name(config['machineType'])
-        for accelerator in config.get('guestAccelerators', []):
-            accelerator['acceleratorType'] = selflink_to_name(
-                accelerator['acceleratorType'])
+        cls._convert_selflinks_in_config(config)
 
         body = {
             'count': len(names),
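Factoring the selflink normalization into _convert_selflinks_in_config lets the new MIG path below reuse it. For intuition, selflink_to_name (defined earlier in this file, not shown in the diff) reduces a full selflink URL to its bare resource name; a hedged sketch of that behavior:

def selflink_to_name(selflink: str) -> str:
    # Assumed behavior: keep only the trailing path segment, e.g.
    # '.../zones/us-central1-a/machineTypes/n1-standard-4' -> 'n1-standard-4'.
    return selflink.rsplit('/', maxsplit=1)[-1]


config = {
    'machineType': ('https://www.googleapis.com/compute/v1/projects/my-proj/'
                    'zones/us-central1-a/machineTypes/n1-standard-4'),
}
config['machineType'] = selflink_to_name(config['machineType'])
assert config['machineType'] == 'n1-standard-4'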
@@ -819,7 +856,7 @@ class GCPComputeInstance(GCPInstance):
                 })
             logger.debug(
                 f'create_instances: googleapiclient.errors.HttpError: {e}')
-            _log_errors(errors, e, zone)
+            _format_and_log_message_from_errors(errors, e, zone)
             return errors
 
         # Allow Google Compute Engine instance templates.
@@ -849,13 +886,13 @@ class GCPComputeInstance(GCPInstance):
         if errors:
             logger.debug('create_instances: Failed to create instances. '
                          f'Reason: {errors}')
-            _log_errors(errors, operations, zone)
+            _format_and_log_message_from_errors(errors, operations, zone)
             return errors
 
         logger.debug('Waiting GCP instances to be ready ...')
         try:
             for operation in operations:
-                cls.wait_for_operation(operation, project_id, zone)
+                cls.wait_for_operation(operation, project_id, zone=zone)
         except common.ProvisionerError as e:
             return e.errors
         except gcp.http_error_exception() as e:
@@ -876,7 +913,7 @@ class GCPComputeInstance(GCPInstance):
                 instance=node_id,
             ).execute())
 
-        cls.wait_for_operation(operation, project_id, zone)
+        cls.wait_for_operation(operation, project_id, zone=zone)
 
     @classmethod
     def get_instance_info(cls, project_id: str, availability_zone: str,
@@ -935,7 +972,220 @@ class GCPComputeInstance(GCPInstance):
             logger.warning(f'googleapiclient.errors.HttpError: {e.reason}')
             return
 
-        cls.wait_for_operation(operation, project_id, availability_zone)
+        cls.wait_for_operation(operation, project_id, zone=availability_zone)
+
+
+class GCPManagedInstanceGroup(GCPComputeInstance):
+    """Handler for GCP Managed Instance Group."""
+
+    @classmethod
+    def create_instances(
+        cls,
+        cluster_name: str,
+        project_id: str,
+        zone: str,
+        node_config: dict,
+        labels: dict,
+        count: int,
+        total_count: int,
+        include_head_node: bool,
+    ) -> Tuple[Optional[List], List[str]]:
+        logger.debug(f'Creating cluster with MIG: {cluster_name!r}')
+        config = copy.deepcopy(node_config)
+        labels = dict(config.get('labels', {}), **labels)
+
+        config.update({
+            'labels': dict(
+                labels,
+                **{
+                    provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name,
+                    # Assume all nodes are workers, we can update the head node
+                    # once the instances are created.
+                    **provision_constants.WORKER_NODE_TAGS,
+                    provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name,
+                }),
+        })
+        cls._convert_selflinks_in_config(config)
+
+        # Convert label values to string and lowercase per MIG API requirement.
+        region = zone.rpartition('-')[0]
+        instance_template_name = mig_utils.get_instance_template_name(
+            cluster_name)
+        managed_instance_group_name = mig_utils.get_managed_instance_group_name(
+            cluster_name)
+
+        instance_template_exists = mig_utils.check_instance_template_exits(
+            project_id, region, instance_template_name)
+        mig_exists = mig_utils.check_managed_instance_group_exists(
+            project_id, zone, managed_instance_group_name)
+
+        label_filters = {
+            provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name,
+        }
+        potential_head_instances = []
+        if mig_exists:
+            instances = cls.filter(
+                project_id,
+                zone,
+                label_filters={
+                    provision_constants.TAG_RAY_NODE_KIND: 'head',
+                    **label_filters,
+                },
+                status_filters=cls.NEED_TO_TERMINATE_STATES)
+            potential_head_instances = list(instances.keys())
+
+        config['labels'] = {
+            k: str(v).lower() for k, v in config['labels'].items()
+        }
+        if instance_template_exists:
+            if mig_exists:
+                logger.debug(
+                    f'Instance template {instance_template_name} already '
+                    'exists. Skip creating it.')
+            else:
+                logger.debug(
+                    f'Instance template {instance_template_name!r} '
+                    'exists and no instance group is using it. This is a '
+                    'leftover of a previous autodown. Delete it and recreate '
+                    'it.')
+                # TODO(zhwu): this is a bit hacky as we cannot delete instance
+                # template during an autodown, we can only defer the deletion
+                # to the next launch of a cluster with the same name. We should
+                # find a better way to handle this.
+                cls._delete_instance_template(project_id, zone,
+                                              instance_template_name)
+                instance_template_exists = False
+
+        if not instance_template_exists:
+            operation = mig_utils.create_region_instance_template(
+                cluster_name, project_id, region, instance_template_name,
+                config)
+            cls.wait_for_operation(operation, project_id, region=region)
+        # create managed instance group
+        instance_template_url = (f'projects/{project_id}/regions/{region}/'
+                                 f'instanceTemplates/{instance_template_name}')
+        if not mig_exists:
+            # Create a new MIG with size 0 and resize it later for triggering
+            # DWS, according to the doc: https://cloud.google.com/compute/docs/instance-groups/create-mig-with-gpu-vms # pylint: disable=line-too-long
+            operation = mig_utils.create_managed_instance_group(
+                project_id,
+                zone,
+                managed_instance_group_name,
+                instance_template_url,
+                size=0)
+            cls.wait_for_operation(operation, project_id, zone=zone)
+
+        managed_instance_group_config = config[
+            constants.MANAGED_INSTANCE_GROUP_CONFIG]
+        if count > 0:
+            # Use resize to trigger DWS for creating VMs.
+            operation = mig_utils.resize_managed_instance_group(
+                project_id,
+                zone,
+                managed_instance_group_name,
+                count,
+                run_duration=managed_instance_group_config['run_duration'])
+            cls.wait_for_operation(operation, project_id, zone=zone)
+
+        # This will block the provisioning until the nodes are ready, which
+        # makes the failover not effective. We rely on the request timeout set
+        # by user to trigger failover.
+        mig_utils.wait_for_managed_group_to_be_stable(
+            project_id,
+            zone,
+            managed_instance_group_name,
+            timeout=managed_instance_group_config.get(
+                'provision_timeout',
+                constants.DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT))
+
+        pending_running_instance_names = cls._add_labels_and_find_head(
+            cluster_name, project_id, zone, labels, potential_head_instances)
+        assert len(pending_running_instance_names) == total_count, (
+            pending_running_instance_names, total_count)
+        cls.create_node_tag(
+            project_id,
+            zone,
+            pending_running_instance_names[0],
+            is_head=True,
+        )
+        return None, pending_running_instance_names
+
+    @classmethod
+    def _delete_instance_template(cls, project_id: str, zone: str,
+                                  instance_template_name: str) -> None:
+        logger.debug(f'Deleting instance template {instance_template_name}...')
+        region = zone.rpartition('-')[0]
+        try:
+            operation = cls.load_resource().regionInstanceTemplates().delete(
+                project=project_id,
+                region=region,
+                instanceTemplate=instance_template_name).execute()
+            cls.wait_for_operation(operation, project_id, region=region)
+        except gcp.http_error_exception() as e:
+            if re.search(mig_utils.IT_RESOURCE_NOT_FOUND_PATTERN,
+                         str(e)) is None:
+                raise
+            logger.warning(
+                f'Instance template {instance_template_name!r} does not exist. '
+                'Skip deletion.')
+
+    @classmethod
+    def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None:
+        mig_name = mig_utils.get_managed_instance_group_name(cluster_name)
+        # Get all resize request of the MIG and cancel them.
+        mig_utils.cancel_all_resize_request_for_mig(project_id, zone, mig_name)
+        logger.debug(f'Deleting MIG {mig_name!r} ...')
+        try:
+            operation = cls.load_resource().instanceGroupManagers().delete(
+                project=project_id, zone=zone,
+                instanceGroupManager=mig_name).execute()
+            cls.wait_for_operation(operation, project_id, zone=zone)
+        except gcp.http_error_exception() as e:
+            if re.search(mig_utils.MIG_RESOURCE_NOT_FOUND_PATTERN,
+                         str(e)) is None:
+                raise
+            logger.warning(f'MIG {mig_name!r} does not exist. Skip '
+                           'deletion.')
+
+        # In the autostop case, the following deletion of instance template
+        # will not be executed as the instance that runs the deletion will be
+        # terminated with the managed instance group. It is ok to leave the
+        # instance template there as when a user creates a new cluster with the
+        # same name, the instance template will be updated in our
+        # create_instances method.
+        cls._delete_instance_template(
+            project_id, zone,
+            mig_utils.get_instance_template_name(cluster_name))
+
+    @classmethod
+    def _add_labels_and_find_head(
+            cls, cluster_name: str, project_id: str, zone: str,
+            labels: Dict[str, str],
+            potential_head_instances: List[str]) -> List[str]:
+        pending_running_instances = cls.filter(
+            project_id,
+            zone,
+            {provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name},
+            # Find all provisioning and running instances.
+            status_filters=cls.NEED_TO_STOP_STATES)
+        for running_instance_name in pending_running_instances.keys():
+            if running_instance_name in potential_head_instances:
+                head_instance_name = running_instance_name
+                break
+        else:
+            head_instance_name = list(pending_running_instances.keys())[0]
+        # We need to update the node's label if mig already exists, as the
+        # config is not updated during the resize operation.
+        for instance_name in pending_running_instances.keys():
+            cls.set_labels(project_id=project_id,
+                           availability_zone=zone,
+                           node_id=instance_name,
+                           labels=labels)
+
+        pending_running_instance_names = list(pending_running_instances.keys())
+        pending_running_instance_names.remove(head_instance_name)
+        # Label for head node type will be set by caller
+        return [head_instance_name] + pending_running_instance_names
 
 
 class GCPTPUVMInstance(GCPInstance):
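The new MIG handler provisions through GCP's Dynamic Workload Scheduler (DWS) in four sequential steps. A condensed, self-contained sketch of that ordering (the helper bodies below are stand-ins for the real functions added in sky/provision/gcp/mig_utils.py, with signatures simplified for illustration):

from typing import Any, Dict


def create_region_instance_template(name: str, config: Dict[str, Any]) -> dict:
    return {'name': f'create-template-{name}'}  # stand-in for mig_utils


def create_managed_instance_group(name: str, template_url: str,
                                  size: int) -> dict:
    return {'name': f'create-mig-{name}'}  # stand-in for mig_utils


def resize_managed_instance_group(name: str, count: int,
                                  run_duration: int) -> dict:
    return {'name': f'resize-mig-{name}'}  # stand-in for mig_utils


def wait_for_operation(operation: dict) -> None:
    print(f'waiting on {operation["name"]} ...')  # stand-in for polling


def provision_with_dws(cluster_name: str, region: str, count: int,
                       run_duration: int) -> None:
    template = f'{cluster_name}-template'
    mig = f'{cluster_name}-mig'
    # 1. A regional instance template describes the VM shape (incl. GPUs).
    wait_for_operation(create_region_instance_template(template, config={}))
    # 2. The MIG is created empty (size=0); no capacity is requested yet.
    url = f'projects/my-proj/regions/{region}/instanceTemplates/{template}'
    wait_for_operation(create_managed_instance_group(mig, url, size=0))
    # 3. Resizing the group files the DWS request with a bounded run_duration;
    #    this is what actually creates the VMs.
    wait_for_operation(
        resize_managed_instance_group(mig, count, run_duration=run_duration))
    # 4. The real code then blocks until the group reports stable.

Requesting capacity via resize rather than the initial group size is what routes the request through DWS, per the GCP doc linked in the hunk above.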
@@ -959,12 +1209,15 @@ class GCPTPUVMInstance(GCPInstance):
             discoveryServiceUrl='https://tpu.googleapis.com/$discovery/rest')
 
     @classmethod
-    def wait_for_operation(cls, operation: dict, project_id: str,
-                           zone: Optional[str]) -> None:
+    def wait_for_operation(cls,
+                           operation: dict,
+                           project_id: str,
+                           region: Optional[str] = None,
+                           zone: Optional[str] = None) -> None:
         """Poll for TPU operation until finished."""
-        del project_id, zone  # unused
+        del project_id, region, zone  # unused
 
-        @_retry_on_http_exception(
+        @_retry_on_gcp_http_exception(
             f'Failed to wait for operation {operation["name"]}')
         def call_operation(fn, timeout: int):
             request = fn(name=operation['name'])
@@ -1132,7 +1385,7 @@ class GCPTPUVMInstance(GCPInstance):
                     f'Failed to get VPC name for instance {instance}') from e
 
     @classmethod
-    @_retry_on_http_exception('unable to queue the operation')
+    @_retry_on_gcp_http_exception('unable to queue the operation')
     def set_labels(cls, project_id: str, availability_zone: str, node_id: str,
                    labels: dict) -> None:
         while True:
@@ -1176,6 +1429,7 @@ class GCPTPUVMInstance(GCPInstance):
         node_config: dict,
         labels: dict,
         count: int,
+        total_count: int,
         include_head_node: bool,
     ) -> Tuple[Optional[List], List[str]]:
         config = copy.deepcopy(node_config)
@@ -1198,8 +1452,8 @@ class GCPTPUVMInstance(GCPInstance):
         config.update({
             'labels': dict(
                 labels, **{
-                    TAG_RAY_CLUSTER_NAME: cluster_name,
-                    TAG_SKYPILOT_CLUSTER_NAME: cluster_name
+                    provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name,
+                    provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name
                 }),
         })
 
@@ -1225,11 +1479,10 @@ class GCPTPUVMInstance(GCPInstance):
         for i, name in enumerate(names):
             node_config = config.copy()
             if i == 0:
-                node_config['labels'][TAG_SKYPILOT_HEAD_NODE] = '1'
-                node_config['labels'][TAG_RAY_NODE_KIND] = 'head'
+                node_config['labels'].update(provision_constants.HEAD_NODE_TAGS)
             else:
-                node_config['labels'][TAG_SKYPILOT_HEAD_NODE] = '0'
-                node_config['labels'][TAG_RAY_NODE_KIND] = 'worker'
+                node_config['labels'].update(
+                    provision_constants.WORKER_NODE_TAGS)
             try:
                 logger.debug('Launching GCP TPU VM ...')
                 request = (
@@ -1257,7 +1510,7 @@ class GCPTPUVMInstance(GCPInstance):
                         'domain': 'create_instances',
                         'message': error_details,
                     })
-                _log_errors(errors, e, zone)
+                _format_and_log_message_from_errors(errors, e, zone)
                 return errors, names
             for detail in error_details:
                 # To be consistent with error messages returned by operation
@@ -1276,7 +1529,7 @@ class GCPTPUVMInstance(GCPInstance):
                         'domain': violation.get('subject'),
                         'message': violation.get('description'),
                     })
-                _log_errors(errors, e, zone)
+                _format_and_log_message_from_errors(errors, e, zone)
                 return errors, names
         errors = []
         for operation in operations:
@@ -1294,7 +1547,7 @@ class GCPTPUVMInstance(GCPInstance):
         if errors:
             logger.debug('create_instances: Failed to create instances. '
                          f'Reason: {errors}')
-            _log_errors(errors, operations, zone)
+            _format_and_log_message_from_errors(errors, operations, zone)
             return errors, names
 
         logger.debug('Waiting GCP instances to be ready ...')
@@ -1336,7 +1589,7 @@ class GCPTPUVMInstance(GCPInstance):
                 'message': 'Timeout waiting for creation operation',
                 'domain': 'create_instances'
             }]
-            _log_errors(errors, None, zone)
+            _format_and_log_message_from_errors(errors, None, zone)
             return errors, names
 
         # NOTE: Error example:
@@ -1353,7 +1606,7 @@ class GCPTPUVMInstance(GCPInstance):
             logger.debug(
                 'create_instances: Failed to create instances. Reason: '
                 f'{errors}')
-            _log_errors(errors, results, zone)
+            _format_and_log_message_from_errors(errors, results, zone)
             return errors, names
         assert all(success), (
             'Failed to create instances, but there is no error. '
@@ -1406,10 +1659,11 @@ class GCPNodeType(enum.Enum):
     """Enum for GCP node types (compute & tpu)"""
 
     COMPUTE = 'compute'
+    MIG = 'mig'
     TPU = 'tpu'
 
 
-def get_node_type(node: dict) -> GCPNodeType:
+def get_node_type(config: Dict[str, Any]) -> GCPNodeType:
     """Returns node type based on the keys in ``node``.
 
     This is a very simple check. If we have a ``machineType`` key,
@@ -1419,17 +1673,22 @@ def get_node_type(node: dict) -> GCPNodeType:
 
     This works for both node configs and API returned nodes.
     """
-
-    if 'machineType' not in node and 'acceleratorType' not in node:
+    if ('machineType' not in config and 'acceleratorType' not in config):
         raise ValueError(
             'Invalid node. For a Compute instance, "machineType" is '
             'required. '
             'For a TPU instance, "acceleratorType" and no "machineType" '
             'is required. '
-            f'Got {list(node)}')
+            f'Got {list(config)}')
 
-    if 'machineType' not in node and 'acceleratorType' in node:
+    if 'machineType' not in config and 'acceleratorType' in config:
         return GCPNodeType.TPU
+
+    if (config.get(constants.MANAGED_INSTANCE_GROUP_CONFIG, None) is not None
+            and config.get('guestAccelerators', None) is not None):
+        # DWS in MIG only works for machine with GPUs.
+        return GCPNodeType.MIG
+
     return GCPNodeType.COMPUTE
 
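The dispatch now has three outcomes. A runnable distillation (the enum values mirror the diff; the MANAGED_INSTANCE_GROUP_CONFIG key value below is an assumption for illustration, the real constant lives in sky/provision/gcp/constants.py):

import enum
from typing import Any, Dict

MANAGED_INSTANCE_GROUP_CONFIG = 'managed-instance-group'  # assumed key name


class GCPNodeType(enum.Enum):
    COMPUTE = 'compute'
    MIG = 'mig'
    TPU = 'tpu'


def get_node_type(config: Dict[str, Any]) -> GCPNodeType:
    if 'machineType' not in config and 'acceleratorType' in config:
        return GCPNodeType.TPU
    if (config.get(MANAGED_INSTANCE_GROUP_CONFIG) is not None and
            config.get('guestAccelerators') is not None):
        # DWS via MIG is only used for GPU machines.
        return GCPNodeType.MIG
    return GCPNodeType.COMPUTE


assert get_node_type({'acceleratorType': 'v4-8'}) is GCPNodeType.TPU
assert get_node_type({'machineType': 'a2-highgpu-1g',
                      'guestAccelerators': [{}],
                      MANAGED_INSTANCE_GROUP_CONFIG: {'run_duration': 600}
                     }) is GCPNodeType.MIG
assert get_node_type({'machineType': 'n1-standard-4'}) is GCPNodeType.COMPUTE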
 
@@ -1475,7 +1734,7 @@ def create_tpu_node(project_id: str, zone: str, tpu_node_config: Dict[str, str],
             'https://console.cloud.google.com/iam-admin/quotas '
             'for more information.'
         }]
-        _log_errors(provisioner_err.errors, e, zone)
+        _format_and_log_message_from_errors(provisioner_err.errors, e, zone)
         raise provisioner_err from e
 
     if 'PERMISSION_DENIED' in stderr:
@@ -1484,7 +1743,7 @@ def create_tpu_node(project_id: str, zone: str, tpu_node_config: Dict[str, str],
             'domain': 'tpu',
             'message': 'TPUs are not available in this zone.'
         }]
-        _log_errors(provisioner_err.errors, e, zone)
+        _format_and_log_message_from_errors(provisioner_err.errors, e, zone)
         raise provisioner_err from e
 
     if 'no more capacity in the zone' in stderr:
@@ -1493,7 +1752,7 @@ def create_tpu_node(project_id: str, zone: str, tpu_node_config: Dict[str, str],
             'domain': 'tpu',
             'message': 'No more capacity in this zone.'
         }]
-        _log_errors(provisioner_err.errors, e, zone)
+        _format_and_log_message_from_errors(provisioner_err.errors, e, zone)
         raise provisioner_err from e
 
     if 'CloudTpu received an invalid AcceleratorType' in stderr:
@@ -1506,7 +1765,7 @@ def create_tpu_node(project_id: str, zone: str, tpu_node_config: Dict[str, str],
             'message': (f'TPU type {tpu_type} is not available in this '
                         f'zone {zone}.')
         }]
-        _log_errors(provisioner_err.errors, e, zone)
+        _format_and_log_message_from_errors(provisioner_err.errors, e, zone)
         raise provisioner_err from e
 
     # TODO(zhwu): Add more error code handling, if needed.
@@ -1515,7 +1774,7 @@ def create_tpu_node(project_id: str, zone: str, tpu_node_config: Dict[str, str],
         'domain': 'tpu',
         'message': stderr
     }]
-    _log_errors(provisioner_err.errors, e, zone)
+    _format_and_log_message_from_errors(provisioner_err.errors, e, zone)
     raise provisioner_err from e
 
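Across all of these hunks the error path follows one pattern: build a list of {code, domain, message} dicts, log them once through _format_and_log_message_from_errors, and attach both the structured errors and the formatted message to the raised ProvisionerError. A self-contained sketch of that pattern (the ProvisionerError class here stands in for sky.provision.common.ProvisionerError):

from typing import Dict, List, Optional


class ProvisionerError(RuntimeError):
    """Stand-in for sky.provision.common.ProvisionerError."""
    errors: List[Dict[str, str]]


def format_and_log(errors: List[Dict[str, str]], zone: Optional[str]) -> str:
    plural = 's' if len(errors) > 1 else ''
    codes = ', '.join(repr(e.get('code', 'N/A')) for e in errors)
    messages = '; '.join(repr(e.get('message', 'N/A')) for e in errors)
    zone_str = f' in {zone}' if zone else ''
    msg = f'Got return code{plural} {codes}{zone_str}: {messages}'
    print(msg)  # the real code uses logger.warning(msg)
    return msg


def raise_provision_error(errors: List[Dict[str, str]], zone: str) -> None:
    msg = format_and_log(errors, zone)
    error = ProvisionerError('Operation failed')
    # detailed_reason is consumed by usage collection, per the diff comments.
    setattr(error, 'detailed_reason', msg)
    error.errors = errors
    raise error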