skypilot-nightly 1.0.0.dev20241109__py3-none-any.whl → 1.0.0.dev20241110__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,430 @@
+ """OCI instance provisioning.
+
+ History:
+  - Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Initial implementation
+ """
+
+ import copy
+ from datetime import datetime
+ import time
+ from typing import Any, Dict, List, Optional
+
+ from sky import exceptions
+ from sky import sky_logging
+ from sky import status_lib
+ from sky.adaptors import oci as oci_adaptor
+ from sky.clouds.utils import oci_utils
+ from sky.provision import common
+ from sky.provision import constants
+ from sky.provision.oci import query_utils
+ from sky.provision.oci.query_utils import query_helper
+ from sky.utils import common_utils
+ from sky.utils import ux_utils
+
+ logger = sky_logging.init_logger(__name__)
+
+
+ @query_utils.debug_enabled(logger)
+ @common_utils.retry
+ def query_instances(
+     cluster_name_on_cloud: str,
+     provider_config: Optional[Dict[str, Any]] = None,
+     non_terminated_only: bool = True,
+ ) -> Dict[str, Optional[status_lib.ClusterStatus]]:
+     """Query instances.
+
+     Returns a dictionary of instance IDs and status.
+
+     A None status means the instance is marked as "terminated"
+     or "terminating".
+     """
+     assert provider_config is not None, cluster_name_on_cloud
+     region = provider_config['region']
+
+     status_map = oci_utils.oci_config.STATE_MAPPING_OCI_TO_SKY
+     statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {}
+     filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+
+     instances = _get_filtered_nodes(region, filters)
+     for node in instances:
+         vm_status = node['status']
+         sky_status = status_map[vm_status]
+         if non_terminated_only and sky_status is None:
+             continue
+         statuses[node['inst_id']] = sky_status
+
+     return statuses
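+ # Illustrative example (hypothetical OCIDs): the mapping returned by
+ # query_instances() might look like
+ #     {'ocid1.instance.oc1..aaaa': status_lib.ClusterStatus.UP,
+ #      'ocid1.instance.oc1..bbbb': None}
+ # where a None value marks a terminated or terminating instance.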
+
+
+ @query_utils.debug_enabled(logger)
+ def run_instances(region: str, cluster_name_on_cloud: str,
+                   config: common.ProvisionConfig) -> common.ProvisionRecord:
+     """Start instances with bootstrapped configuration."""
+     tags = dict(sorted(copy.deepcopy(config.tags).items()))
+
+     start_time = round(time.time() * 1000)
+     filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+
+     # Start stopped nodes if resume_stopped_nodes=True.
+     resume_instances = []
+     if config.resume_stopped_nodes:
+         logger.debug('Checking existing stopped nodes.')
+
+         existing_instances = _get_filtered_nodes(region, filters)
+         if len(existing_instances) > config.count:
+             raise RuntimeError(
+                 'The number of pending/running/stopped/stopping '
+                 f'instances combined ({len(existing_instances)}) in '
+                 f'cluster "{cluster_name_on_cloud}" is greater than the '
+                 f'number requested by the user ({config.count}). '
+                 'This is likely a resource leak. '
+                 'Use "sky down" to terminate the cluster.')
+
+         # pylint: disable=line-too-long
+         logger.debug(
+             f'run_instances: Found {[inst["name"] for inst in existing_instances]} '
+             'existing instances in cluster.')
+         existing_instances.sort(key=lambda x: x['name'])
+
+         stopped_instances = []
+         for existing_node in existing_instances:
+             if existing_node['status'] == 'STOPPING':
+                 query_helper.wait_instance_until_status(
+                     region, existing_node['inst_id'], 'STOPPED')
+                 stopped_instances.append(existing_node)
+             elif existing_node['status'] == 'STOPPED':
+                 stopped_instances.append(existing_node)
+             elif existing_node['status'] in ('PROVISIONING', 'STARTING',
+                                              'RUNNING'):
+                 resume_instances.append(existing_node)
+
+         for stopped_node in stopped_instances:
+             stopped_node_id = stopped_node['inst_id']
+             instance_action_response = query_helper.start_instance(
+                 region, stopped_node_id)
+
+             starting_inst = instance_action_response.data
+             resume_instances.append({
+                 'inst_id': starting_inst.id,
+                 'name': starting_inst.display_name,
+                 'ad': starting_inst.availability_domain,
+                 'compartment': starting_inst.compartment_id,
+                 'status': starting_inst.lifecycle_state,
+                 'oci_tags': starting_inst.freeform_tags,
+             })
+     # end if config.resume_stopped_nodes
+
+     # Try to get the head node id from the existing instances.
+     head_instance_id = _get_head_instance_id(resume_instances)
+     logger.debug(f'Check existing head node: {head_instance_id}')
+
+     # Create additional new nodes if necessary.
+     to_start_count = config.count - len(resume_instances)
+     created_instances = []
+     if to_start_count > 0:
+         node_config = config.node_config
+         compartment = query_helper.find_compartment(region)
+         vcn = query_helper.find_create_vcn_subnet(region)
+
+         ocpu_count = 0
+         vcpu_str = node_config['VCPUs']
+         instance_type_str = node_config['InstanceType']
+
+         if vcpu_str is not None and vcpu_str != 'None':
+             if instance_type_str.startswith(
+                     f'{oci_utils.oci_config.VM_PREFIX}.A'):
+                 # For ARM CPUs, 1 OCPU = 1 vCPU.
+                 ocpu_count = round(float(vcpu_str))
+             else:
+                 # For Intel / AMD CPUs, 1 OCPU = 2 vCPUs.
+                 ocpu_count = round(float(vcpu_str) / 2)
+         ocpu_count = 1 if (ocpu_count > 0 and ocpu_count < 1) else ocpu_count
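+         # Worked example (shape names and counts for illustration only):
+         # 'VM.Standard.A1.Flex' with VCPUs='8' takes the ARM branch,
+         # giving ocpu_count = 8; 'VM.Standard.E4.Flex' with VCPUs='8'
+         # takes the x86 branch, giving ocpu_count = round(8 / 2) = 4.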
+
+         machine_shape_config = None
+         if ocpu_count > 0:
+             mem = node_config['MemoryInGbs']
+             if mem is not None and mem != 'None':
+                 # pylint: disable=line-too-long
+                 machine_shape_config = oci_adaptor.oci.core.models.LaunchInstanceShapeConfigDetails(
+                     ocpus=ocpu_count, memory_in_gbs=mem)
+             else:
+                 # pylint: disable=line-too-long
+                 machine_shape_config = oci_adaptor.oci.core.models.LaunchInstanceShapeConfigDetails(
+                     ocpus=ocpu_count)
+
+         preemptible_config = (
+             oci_adaptor.oci.core.models.PreemptibleInstanceConfigDetails(
+                 preemption_action=oci_adaptor.oci.core.models.
+                 TerminatePreemptionAction(type='TERMINATE',
+                                           preserve_boot_volume=False))
+             if node_config['Preemptible'] else None)
+
+         batch_id = datetime.now().strftime('%Y%m%d%H%M%S')
+
+         vm_tags_head = {
+             **tags,
+             **constants.HEAD_NODE_TAGS,
+             constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
+             'sky_spot_flag': str(node_config['Preemptible']).lower(),
+         }
+         vm_tags_worker = {
+             **tags,
+             **constants.WORKER_NODE_TAGS,
+             constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
+             'sky_spot_flag': str(node_config['Preemptible']).lower(),
+         }
+
+         for seq in range(1, to_start_count + 1):
+             if head_instance_id is None:
+                 vm_tags = vm_tags_head
+                 node_type = constants.HEAD_NODE_TAGS[
+                     constants.TAG_RAY_NODE_KIND]
+             else:
+                 vm_tags = vm_tags_worker
+                 node_type = constants.WORKER_NODE_TAGS[
+                     constants.TAG_RAY_NODE_KIND]
+
+             launch_instance_response = query_helper.launch_instance(
+                 region,
+                 oci_adaptor.oci.core.models.LaunchInstanceDetails(
+                     availability_domain=node_config['AvailabilityDomain'],
+                     compartment_id=compartment,
+                     shape=instance_type_str,
+                     display_name=
+                     f'{cluster_name_on_cloud}_{node_type}_{batch_id}_{seq}',
+                     freeform_tags=vm_tags,
+                     metadata={
+                         'ssh_authorized_keys': node_config['AuthorizedKey']
+                     },
+                     source_details=oci_adaptor.oci.core.models.
+                     InstanceSourceViaImageDetails(
+                         source_type='image',
+                         image_id=node_config['ImageId'],
+                         boot_volume_size_in_gbs=node_config['BootVolumeSize'],
+                         boot_volume_vpus_per_gb=int(
+                             node_config['BootVolumePerf']),
+                     ),
+                     create_vnic_details=oci_adaptor.oci.core.models.
+                     CreateVnicDetails(
+                         assign_public_ip=True,
+                         subnet_id=vcn,
+                     ),
+                     shape_config=machine_shape_config,
+                     preemptible_instance_config=preemptible_config,
+                 ))
+
+             new_inst = launch_instance_response.data
+             if head_instance_id is None:
+                 head_instance_id = new_inst.id
+                 logger.debug(f'New head node: {head_instance_id}')
+
+             created_instances.append({
+                 'inst_id': new_inst.id,
+                 'name': new_inst.display_name,
+                 'ad': new_inst.availability_domain,
+                 'compartment': new_inst.compartment_id,
+                 'status': new_inst.lifecycle_state,
+                 'oci_tags': new_inst.freeform_tags,
+             })
+         # end for loop
+     # end if to_start_count > 0
+
+     for inst in (resume_instances + created_instances):
+         logger.debug(f'Provisioning for node {inst["name"]}')
+         query_helper.wait_instance_until_status(region, inst['inst_id'],
+                                                 'RUNNING')
+         logger.debug(f'Instance {inst["name"]} is RUNNING.')
+
+     total_time = round(time.time() * 1000) - start_time
+     logger.debug('Total time elapsed: {0} milliseconds.'.format(total_time))
+
+     assert head_instance_id is not None, head_instance_id
+
+     return common.ProvisionRecord(
+         provider_name='oci',
+         region=region,
+         zone=None,
+         cluster_name=cluster_name_on_cloud,
+         head_instance_id=head_instance_id,
+         created_instance_ids=[n['inst_id'] for n in created_instances],
+         resumed_instance_ids=[n['inst_id'] for n in resume_instances],
+     )
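+ # Illustrative return value (all IDs hypothetical): resuming one stopped
+ # node and creating one new node for a 2-node cluster could yield
+ #     ProvisionRecord(provider_name='oci', region='us-sanjose-1', zone=None,
+ #                     cluster_name='mycluster-1234',
+ #                     head_instance_id='ocid1.instance.oc1..head',
+ #                     created_instance_ids=['ocid1.instance.oc1..new'],
+ #                     resumed_instance_ids=['ocid1.instance.oc1..head'])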
+
+
+ @query_utils.debug_enabled(logger)
+ def stop_instances(
+     cluster_name_on_cloud: str,
+     provider_config: Dict[str, Any],
+     worker_only: bool = False,
+ ) -> None:
+     """Stop running instances."""
+     # pylint: disable=line-too-long
+     assert provider_config is not None, (cluster_name_on_cloud, provider_config)
+
+     region = provider_config['region']
+     tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+     if worker_only:
+         tag_filters[constants.TAG_RAY_NODE_KIND] = 'worker'
+
+     nodes = _get_filtered_nodes(region, tag_filters)
+     for node in nodes:
+         query_helper.stop_instance(region, node['inst_id'])
+
+
+ @query_utils.debug_enabled(logger)
+ def terminate_instances(
+     cluster_name_on_cloud: str,
+     provider_config: Dict[str, Any],
+     worker_only: bool = False,
+ ) -> None:
+     """Terminate running or stopped instances."""
+     region = provider_config['region']
+     tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+     if worker_only:
+         tag_filters[constants.TAG_RAY_NODE_KIND] = 'worker'
+     query_helper.terminate_instances_by_tags(tag_filters, region)
+
+
+ @query_utils.debug_enabled(logger)
+ def open_ports(
+     cluster_name_on_cloud: str,
+     ports: List[str],
+     provider_config: Optional[Dict[str, Any]] = None,
+ ) -> None:
+     """Open ports for inbound traffic."""
+     # Ports in OCI security groups are opened when the new VCN
+     # (skypilot_vcn) is created. If the user configures an existing VCN,
+     # the user is expected to manage the ports themselves instead of
+     # having them opened automatically here.
+     del cluster_name_on_cloud, ports, provider_config
+
+
+ @query_utils.debug_enabled(logger)
+ def cleanup_ports(
+     cluster_name_on_cloud: str,
+     ports: List[str],
+     provider_config: Optional[Dict[str, Any]] = None,
+ ) -> None:
+     """Delete any opened ports."""
+     del cluster_name_on_cloud, ports, provider_config
+     # Ports in OCI security groups are opened when the new VCN
+     # (skypilot_vcn) is created; the VCN is only created the first time,
+     # when it does not already exist. We do not automatically delete the
+     # VCN when tearing down clusters: it is up to the user to decide
+     # whether to delete it, e.g. from the OCI console.
+
+
+ @query_utils.debug_enabled(logger)
+ def wait_instances(region: str, cluster_name_on_cloud: str,
+                    state: Optional[status_lib.ClusterStatus]) -> None:
+     del region, cluster_name_on_cloud, state
+     # We already wait for the instances to be running in run_instances.
+     # We cannot implement the wait logic here because provisioning
+     # instances are not retrievable via the QL 'query instance resources ...'.
+
+
+ @query_utils.debug_enabled(logger)
+ def get_cluster_info(
+     region: str,
+     cluster_name_on_cloud: str,
+     provider_config: Optional[Dict[str, Any]] = None,
+ ) -> common.ClusterInfo:
+     """Get the metadata of instances in a cluster."""
+     filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+     running_instances = _get_filtered_nodes(region, filters)
+
+     instances = {}
+     for running_instance in running_instances:
+         inst = _get_inst_obj_with_ip(region, running_instance)
+         instances[inst['id']] = [
+             common.InstanceInfo(
+                 instance_id=inst['id'],
+                 internal_ip=inst['internal_ip'],
+                 external_ip=inst['external_ip'],
+                 tags=inst['tags'],
+             )
+         ]
+
+     instances = dict(sorted(instances.items(), key=lambda x: x[0]))
+     logger.debug(f'Cluster info: {instances}')
+
+     head_instance_id = _get_head_instance_id(running_instances)
+     logger.debug(f'Head instance id is {head_instance_id}')
+
+     return common.ClusterInfo(
+         provider_name='oci',
+         head_instance_id=head_instance_id,
+         instances=instances,
+         provider_config=provider_config,
+     )
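+ # Illustrative shape of the returned metadata (hypothetical values):
+ #     ClusterInfo.instances == {
+ #         'ocid1.instance.oc1..head': [common.InstanceInfo(
+ #             instance_id='ocid1.instance.oc1..head',
+ #             internal_ip='10.0.0.2', external_ip='203.0.113.10',
+ #             tags={...})],
+ #     }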
+
+
+ def _get_filtered_nodes(region: str,
+                         tag_filters: Dict[str, str]) -> List[Dict[str, Any]]:
+     return_nodes = []
+
+     try:
+         insts = query_helper.query_instances_by_tags(tag_filters, region)
+     except oci_adaptor.oci.exceptions.ServiceError as e:
+         with ux_utils.print_exception_no_traceback():
+             raise exceptions.ClusterStatusFetchingError(
+                 f'Failed to query status for OCI cluster {tag_filters}. '
+                 'Details: '
+                 f'{common_utils.format_exception(e, use_bracket=True)}')
+
+     for inst in insts:
+         inst_id = inst.identifier
+         return_nodes.append({
+             'inst_id': inst_id,
+             'name': inst.display_name,
+             'ad': inst.availability_domain,
+             'compartment': inst.compartment_id,
+             'status': inst.lifecycle_state,
+             'oci_tags': inst.freeform_tags,
+         })
+
+     return return_nodes
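+ # Each node record is a plain dict; an example with hypothetical values:
+ #     {'inst_id': 'ocid1.instance.oc1..xyz',
+ #      'name': 'mycluster_head_20241016104500_1',
+ #      'ad': 'US-SANJOSE-1-AD-1',
+ #      'compartment': 'ocid1.compartment.oc1..abc',
+ #      'status': 'RUNNING',
+ #      'oci_tags': {constants.TAG_RAY_CLUSTER_NAME: 'mycluster', ...}}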
+
+
+ def _get_inst_obj_with_ip(region: str,
+                           inst_info: Dict[str, Any]) -> Dict[str, Any]:
+     get_vnic_response = query_helper.get_instance_primary_vnic(
+         region, inst_info)
+     internal_ip = get_vnic_response.private_ip
+     external_ip = get_vnic_response.public_ip
+     if external_ip is None:
+         external_ip = internal_ip
+
+     return {
+         'id': inst_info['inst_id'],
+         'name': inst_info['name'],
+         'external_ip': external_ip,
+         'internal_ip': internal_ip,
+         'tags': inst_info['oci_tags'],
+         'status': inst_info['status'],
+     }
+
+
+ def _get_head_instance_id(instances: List[Dict[str, Any]]) -> Optional[str]:
+     head_instance_id = None
+     head_node_tags = tuple(constants.HEAD_NODE_TAGS.items())
+     for inst in instances:
+         is_matched = True
+         for k, v in head_node_tags:
+             if (k, v) not in inst['oci_tags'].items():
+                 is_matched = False
+                 break
+         if is_matched:
+             if head_instance_id is not None:
+                 logger.warning(
+                     'There are multiple head nodes in the cluster '
+                     f'(current head instance id: {head_instance_id}, '
+                     f'newly discovered id: {inst["inst_id"]}). It is likely '
+                     'that something went wrong.')
+                 # Don't break here: keep checking so we can warn the user
+                 # about every duplicate head instance and let them take
+                 # further action on the abnormal cluster.
+
+             head_instance_id = inst['inst_id']
+
+     return head_instance_id
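+ # Illustrative matching behavior: an instance is treated as a head node
+ # when its freeform tags contain every key/value pair in
+ # constants.HEAD_NODE_TAGS; if several instances match, each extra match
+ # logs the warning above and the last match wins.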