skypilot-nightly 1.0.0.dev20250101__py3-none-any.whl → 1.0.0.dev20250102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,287 @@
1
+ """DigitalOcean instance provisioning."""
2
+
3
+ import time
4
+ from typing import Any, Dict, List, Optional
5
+ import uuid
6
+
7
+ from sky import sky_logging
8
+ from sky import status_lib
9
+ from sky.provision import common
10
+ from sky.provision.do import constants
11
+ from sky.provision.do import utils
12
+
13
+ # The maximum number of times to poll for the status of an operation
14
+ MAX_POLLS = 60 // constants.POLL_INTERVAL
15
+ # Stopping instances can take several minutes, so we increase the timeout
16
+ MAX_POLLS_FOR_UP_OR_STOP = MAX_POLLS * 8
17
+
18
+ logger = sky_logging.init_logger(__name__)
19
+
20
+
21
+ def _get_head_instance(
22
+ instances: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
23
+ for instance_name, instance_meta in instances.items():
24
+ if instance_name.endswith('-head'):
25
+ return instance_meta
26
+ return None
27
+
28
+
29
+ def run_instances(region: str, cluster_name_on_cloud: str,
30
+ config: common.ProvisionConfig) -> common.ProvisionRecord:
31
+ """Runs instances for the given cluster."""
32
+
33
+ pending_status = ['new']
34
+ newly_started_instances = utils.filter_instances(cluster_name_on_cloud,
35
+ pending_status + ['off'])
36
+ while True:
37
+ instances = utils.filter_instances(cluster_name_on_cloud,
38
+ pending_status)
39
+ if not instances:
40
+ break
41
+ instance_statuses = [
42
+ instance['status'] for instance in instances.values()
43
+ ]
44
+ logger.info(f'Waiting for {len(instances)} instances to be ready: '
45
+ f'{instance_statuses}')
46
+ time.sleep(constants.POLL_INTERVAL)
47
+
48
+ exist_instances = utils.filter_instances(cluster_name_on_cloud,
49
+ status_filters=pending_status +
50
+ ['active', 'off'])
51
+ if len(exist_instances) > config.count:
52
+ raise RuntimeError(
53
+ f'Cluster {cluster_name_on_cloud} already has '
54
+ f'{len(exist_instances)} nodes, but {config.count} are required.')
55
+
56
+ stopped_instances = utils.filter_instances(cluster_name_on_cloud,
57
+ status_filters=['off'])
58
+ for instance in stopped_instances.values():
59
+ utils.start_instance(instance)
60
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
61
+ instances = utils.filter_instances(cluster_name_on_cloud, ['off'])
62
+ if len(instances) == 0:
63
+ break
64
+ num_stopped_instances = len(stopped_instances)
65
+ num_restarted_instances = num_stopped_instances - len(instances)
66
+ logger.info(
67
+ f'Waiting for {num_restarted_instances}/{num_stopped_instances} '
68
+ 'stopped instances to be restarted.')
69
+ time.sleep(constants.POLL_INTERVAL)
70
+ else:
71
+ msg = ('run_instances: Failed to restart all '
72
+ 'instances possibly due to a capacity issue.')
73
+ logger.warning(msg)
74
+ raise RuntimeError(msg)
75
+
76
+ exist_instances = utils.filter_instances(cluster_name_on_cloud,
77
+ status_filters=['active'])
78
+ head_instance = _get_head_instance(exist_instances)
79
+ to_start_count = config.count - len(exist_instances)
80
+ if to_start_count < 0:
81
+ raise RuntimeError(
82
+ f'Cluster {cluster_name_on_cloud} already has '
83
+ f'{len(exist_instances)} nodes, but {config.count} are required.')
84
+ if to_start_count == 0:
85
+ if head_instance is None:
86
+ head_instance = list(exist_instances.values())[0]
87
+ utils.rename_instance(
88
+ head_instance,
89
+ f'{cluster_name_on_cloud}-{uuid.uuid4().hex[:4]}-head')
90
+ assert head_instance is not None, ('`head_instance` should not be None')
91
+ logger.info(f'Cluster {cluster_name_on_cloud} already has '
92
+ f'{len(exist_instances)} nodes, no need to start more.')
93
+ return common.ProvisionRecord(
94
+ provider_name='do',
95
+ cluster_name=cluster_name_on_cloud,
96
+ region=region,
97
+ zone=None,
98
+ head_instance_id=head_instance['name'],
99
+ resumed_instance_ids=list(newly_started_instances.keys()),
100
+ created_instance_ids=[],
101
+ )
102
+
103
+ created_instances: List[Dict[str, Any]] = []
104
+ for _ in range(to_start_count):
105
+ instance_type = 'head' if head_instance is None else 'worker'
106
+ instance = utils.create_instance(
107
+ region=region,
108
+ cluster_name_on_cloud=cluster_name_on_cloud,
109
+ instance_type=instance_type,
110
+ config=config)
111
+ logger.info(f'Launched instance {instance["name"]}.')
112
+ created_instances.append(instance)
113
+ if head_instance is None:
114
+ head_instance = instance
115
+
116
+ # Wait for instances to be ready.
117
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
118
+ instances = utils.filter_instances(cluster_name_on_cloud,
119
+ status_filters=['active'])
120
+ logger.info('Waiting for instances to be ready: '
121
+ f'({len(instances)}/{config.count}).')
122
+ if len(instances) == config.count:
123
+ break
124
+
125
+ time.sleep(constants.POLL_INTERVAL)
126
+ else:
127
+ # Failed to launch config.count of instances after max retries
128
+ msg = 'run_instances: Failed to create the instances'
129
+ logger.warning(msg)
130
+ raise RuntimeError(msg)
131
+ assert head_instance is not None, 'head_instance should not be None'
132
+ return common.ProvisionRecord(
133
+ provider_name='do',
134
+ cluster_name=cluster_name_on_cloud,
135
+ region=region,
136
+ zone=None,
137
+ head_instance_id=head_instance['name'],
138
+ resumed_instance_ids=list(stopped_instances.keys()),
139
+ created_instance_ids=[
140
+ instance['name'] for instance in created_instances
141
+ ],
142
+ )
143
+
144
+
145
+ def wait_instances(region: str, cluster_name_on_cloud: str,
146
+ state: Optional[status_lib.ClusterStatus]) -> None:
147
+ del region, cluster_name_on_cloud, state # unused
148
+ # We already waited for the ready state in `run_instances`, so no further wait is needed here.
149
+
150
+
151
+ def stop_instances(
152
+ cluster_name_on_cloud: str,
153
+ provider_config: Optional[Dict[str, Any]] = None,
154
+ worker_only: bool = False,
155
+ ) -> None:
156
+ del provider_config # unused
157
+ all_instances = utils.filter_instances(cluster_name_on_cloud,
158
+ status_filters=None)
159
+ num_instances = len(all_instances)
160
+
161
+ # Request a stop on all instances
162
+ for instance_name, instance_meta in all_instances.items():
163
+ if worker_only and instance_name.endswith('-head'):
164
+ num_instances -= 1
165
+ continue
166
+ utils.stop_instance(instance_meta)
167
+
168
+ # Wait for instances to stop
169
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
170
+ all_instances = utils.filter_instances(cluster_name_on_cloud, ['off'])
171
+ if len(all_instances) >= num_instances:
172
+ break
173
+ time.sleep(constants.POLL_INTERVAL)
174
+ else:
175
+ raise RuntimeError(f'Maximum number of polls: '
176
+ f'{MAX_POLLS_FOR_UP_OR_STOP} reached. '
177
+ f'Instance {all_instances} is still not in '
178
+ 'STOPPED status.')
179
+
180
+
181
+ def terminate_instances(
182
+ cluster_name_on_cloud: str,
183
+ provider_config: Optional[Dict[str, Any]] = None,
184
+ worker_only: bool = False,
185
+ ) -> None:
186
+ """See sky/provision/__init__.py"""
187
+ del provider_config # unused
188
+ instances = utils.filter_instances(cluster_name_on_cloud,
189
+ status_filters=None)
190
+ for instance_name, instance_meta in instances.items():
191
+ logger.debug(f'Terminating instance {instance_name}')
192
+ if worker_only and instance_name.endswith('-head'):
193
+ continue
194
+ utils.down_instance(instance_meta)
195
+
196
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP):
197
+ instances = utils.filter_instances(cluster_name_on_cloud,
198
+ status_filters=None)
199
+ if len(instances) == 0 or len(instances) <= 1 and worker_only:
200
+ break
201
+ time.sleep(constants.POLL_INTERVAL)
202
+ else:
203
+ msg = ('Failed to delete all instances')
204
+ logger.warning(msg)
205
+ raise RuntimeError(msg)
206
+
207
+
208
+ def get_cluster_info(
209
+ region: str,
210
+ cluster_name_on_cloud: str,
211
+ provider_config: Optional[Dict[str, Any]] = None,
212
+ ) -> common.ClusterInfo:
213
+ del region # unused
214
+ running_instances = utils.filter_instances(cluster_name_on_cloud,
215
+ ['active'])
216
+ instances: Dict[str, List[common.InstanceInfo]] = {}
217
+ head_instance: Optional[str] = None
218
+ for instance_name, instance_meta in running_instances.items():
219
+ if instance_name.endswith('-head'):
220
+ head_instance = instance_name
221
+ for net in instance_meta['networks']['v4']:
222
+ if net['type'] == 'public':
223
+ instance_ip = net['ip_address']
224
+ break
225
+ instances[instance_name] = [
226
+ common.InstanceInfo(
227
+ instance_id=instance_meta['name'],
228
+ internal_ip=instance_ip,
229
+ external_ip=instance_ip,
230
+ ssh_port=22,
231
+ tags={},
232
+ )
233
+ ]
234
+
235
+ assert head_instance is not None, 'no head instance found'
236
+ return common.ClusterInfo(
237
+ instances=instances,
238
+ head_instance_id=head_instance,
239
+ provider_name='do',
240
+ provider_config=provider_config,
241
+ )
242
+
243
+
244
+ def query_instances(
245
+ cluster_name_on_cloud: str,
246
+ provider_config: Optional[Dict[str, Any]] = None,
247
+ non_terminated_only: bool = True,
248
+ ) -> Dict[str, Optional[status_lib.ClusterStatus]]:
249
+ """See sky/provision/__init__.py"""
250
+ # terminated instances are not retrieved by the
251
+ # API making `non_terminated_only` argument moot.
252
+ del non_terminated_only
253
+ assert provider_config is not None, (cluster_name_on_cloud, provider_config)
254
+ instances = utils.filter_instances(cluster_name_on_cloud,
255
+ status_filters=None)
256
+
257
+ status_map = {
258
+ 'new': status_lib.ClusterStatus.INIT,
259
+ 'archive': status_lib.ClusterStatus.INIT,
260
+ 'active': status_lib.ClusterStatus.UP,
261
+ 'off': status_lib.ClusterStatus.STOPPED,
262
+ }
263
+ statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {}
264
+ for instance_meta in instances.values():
265
+ status = status_map[instance_meta['status']]
266
+ statuses[instance_meta['name']] = status
267
+ return statuses
268
+
269
+
270
+ def open_ports(
271
+ cluster_name_on_cloud: str,
272
+ ports: List[str],
273
+ provider_config: Optional[Dict[str, Any]] = None,
274
+ ) -> None:
275
+ """See sky/provision/__init__.py"""
276
+ logger.debug(
277
+ f'Skip opening ports {ports} for DigitalOcean instances, as all '
278
+ 'ports are open by default.')
279
+ del cluster_name_on_cloud, provider_config, ports
280
+
281
+
282
+ def cleanup_ports(
283
+ cluster_name_on_cloud: str,
284
+ ports: List[str],
285
+ provider_config: Optional[Dict[str, Any]] = None,
286
+ ) -> None:
287
+ del cluster_name_on_cloud, provider_config, ports
@@ -0,0 +1,306 @@
1
+ """DigitalOcean API client wrapper for SkyPilot.
2
+
3
+ Example usage of `pydo` client library was mostly taken from here:
4
+ https://github.com/digitalocean/pydo/blob/main/examples/poc_droplets_volumes_sshkeys.py
5
+ """
6
+
7
+ import copy
8
+ import os
9
+ import typing
10
+ from typing import Any, Dict, List, Optional
11
+ import urllib
12
+ import uuid
13
+
14
+ from sky import sky_logging
15
+ from sky.adaptors import do
16
+ from sky.provision import common
17
+ from sky.provision import constants as provision_constants
18
+ from sky.provision.do import constants
19
+ from sky.utils import common_utils
20
+
21
+ if typing.TYPE_CHECKING:
22
+ from sky import resources
23
+ from sky import status_lib
24
+
25
+ logger = sky_logging.init_logger(__name__)
26
+
27
+ POSSIBLE_CREDENTIALS_PATHS = [
28
+ os.path.expanduser(
29
+ '~/Library/Application Support/doctl/config.yaml'), # OS X
30
+ os.path.expanduser(
31
+ os.path.join(os.getenv('XDG_CONFIG_HOME', '~/.config/'),
32
+ 'doctl/config.yaml')), # Linux
33
+ ]
34
+ INITIAL_BACKOFF_SECONDS = 10
35
+ MAX_BACKOFF_FACTOR = 10
36
+ MAX_ATTEMPTS = 6
37
+ SSH_KEY_NAME_ON_DO = f'sky-key-{common_utils.get_user_hash()}'
38
+
39
+ CREDENTIALS_PATH = '~/.config/doctl/config.yaml'
40
+ _client = None
41
+ _ssh_key_id = None
42
+
43
+
44
+ class DigitalOceanError(Exception):
45
+ pass
46
+
47
+
48
+ def _init_client():
49
+ global _client, CREDENTIALS_PATH
50
+ assert _client is None
51
+ CREDENTIALS_PATH = None
52
+ credentials_found = 0
53
+ for path in POSSIBLE_CREDENTIALS_PATHS:
54
+ if os.path.exists(path):
55
+ CREDENTIALS_PATH = path
56
+ credentials_found += 1
57
+ logger.debug(f'Digital Ocean credential path found at {path}')
58
+ if credentials_found > 1:
59
+ logger.debug('more than 1 credential file found')
60
+ if CREDENTIALS_PATH is None:
61
+ raise DigitalOceanError(
62
+ 'no credentials file found from '
63
+ f'the following paths {POSSIBLE_CREDENTIALS_PATHS}')
64
+
65
+ # attempt default context
66
+ credentials = common_utils.read_yaml(CREDENTIALS_PATH)
67
+ default_token = credentials.get('access-token', None)
68
+ if default_token is not None:
69
+ try:
70
+ test_client = do.pydo.Client(token=default_token)
71
+ test_client.droplets.list()
72
+ logger.debug('trying `default` context')
73
+ _client = test_client
74
+ return _client
75
+ except do.exceptions().HttpResponseError:
76
+ pass
77
+
78
+ auth_contexts = credentials.get('auth-contexts', None)
79
+ if auth_contexts is not None:
80
+ for context, api_token in auth_contexts.items():
81
+ try:
82
+ test_client = do.pydo.Client(token=api_token)
83
+ test_client.droplets.list()
84
+ logger.debug(f'using {context} context')
85
+ _client = test_client
86
+ break
87
+ except do.exceptions().HttpResponseError:
88
+ continue
89
+ else:
90
+ raise DigitalOceanError(
91
+ 'no valid api tokens found try '
92
+ 'setting a new API token with `doctl auth init`')
93
+ return _client
94
+
95
+
96
+ def client():
97
+ global _client
98
+ if _client is None:
99
+ _client = _init_client()
100
+ return _client
101
+
102
+
103
+ def ssh_key_id(public_key: str):
104
+ global _ssh_key_id
105
+ if _ssh_key_id is None:
106
+ page = 1
107
+ paginated = True
108
+ while paginated:
109
+ try:
110
+ resp = client().ssh_keys.list(per_page=50, page=page)
111
+ for ssh_key in resp['ssh_keys']:
112
+ if ssh_key['public_key'] == public_key:
113
+ _ssh_key_id = ssh_key
114
+ return _ssh_key_id
115
+ except do.exceptions().HttpResponseError as err:
116
+ raise DigitalOceanError(
117
+ f'Error: {err.status_code} {err.reason}: '
118
+ f'{err.error.message}') from err
119
+
120
+ pages = resp['links']
121
+ if 'pages' in pages and 'next' in pages['pages']:
122
+ pages = pages['pages']
123
+ parsed_url = urllib.parse.urlparse(pages['next'])
124
+ page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
125
+ else:
126
+ paginated = False
127
+
128
+ request = {
129
+ 'public_key': public_key,
130
+ 'name': SSH_KEY_NAME_ON_DO,
131
+ }
132
+ _ssh_key_id = client().ssh_keys.create(body=request)['ssh_key']
133
+ return _ssh_key_id
134
+
135
+
136
+ def _create_volume(request: Dict[str, Any]) -> Dict[str, Any]:
137
+ try:
138
+ resp = client().volumes.create(body=request)
139
+ volume = resp['volume']
140
+ except do.exceptions().HttpResponseError as err:
141
+ raise DigitalOceanError(
142
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
143
+ ) from err
144
+ else:
145
+ return volume
146
+
147
+
148
+ def _create_droplet(request: Dict[str, Any]) -> Dict[str, Any]:
149
+ try:
150
+ resp = client().droplets.create(body=request)
151
+ droplet_id = resp['droplet']['id']
152
+
153
+ get_resp = client().droplets.get(droplet_id)
154
+ droplet = get_resp['droplet']
155
+ except do.exceptions().HttpResponseError as err:
156
+ raise DigitalOceanError(
157
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
158
+ ) from err
159
+ return droplet
160
+
161
+
162
+ def create_instance(region: str, cluster_name_on_cloud: str, instance_type: str,
163
+ config: common.ProvisionConfig) -> Dict[str, Any]:
164
+ """Creates a instance and mounts the requested block storage
165
+
166
+ Args:
167
+ region (str): instance region
168
+ cluster_name_on_cloud (str): name of the cluster the instance belongs to
169
+ config (common.ProvisionConfig): provisioner configuration
170
+
171
+ Returns:
172
+ Dict[str, Any]: instance metadata
173
+ """
174
+ # sort tags by key to support deterministic unit test stubbing
175
+ tags = dict(sorted(copy.deepcopy(config.tags).items()))
176
+ tags = {
177
+ 'Name': cluster_name_on_cloud,
178
+ provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
179
+ provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
180
+ **tags
181
+ }
182
+ tags = [f'{key}:{value}' for key, value in tags.items()]
183
+ default_image = constants.GPU_IMAGES.get(
184
+ config.node_config['InstanceType'],
185
+ 'gpu-h100x1-base',
186
+ )
187
+ image_id = config.node_config['ImageId']
188
+ image_id = image_id if image_id is not None else default_image
189
+ instance_name = (f'{cluster_name_on_cloud}-'
190
+ f'{uuid.uuid4().hex[:4]}-{instance_type}')
191
+ instance_request = {
192
+ 'name': instance_name,
193
+ 'region': region,
194
+ 'size': config.node_config['InstanceType'],
195
+ 'image': image_id,
196
+ 'ssh_keys': [
197
+ ssh_key_id(
198
+ config.authentication_config['ssh_public_key'])['fingerprint']
199
+ ],
200
+ 'tags': tags,
201
+ }
202
+ instance = _create_droplet(instance_request)
203
+
204
+ volume_request = {
205
+ 'size_gigabytes': config.node_config['DiskSize'],
206
+ 'name': instance_name,
207
+ 'region': region,
208
+ 'filesystem_type': 'ext4',
209
+ 'tags': tags
210
+ }
211
+ volume = _create_volume(volume_request)
212
+
213
+ attach_request = {'type': 'attach', 'droplet_id': instance['id']}
214
+ try:
215
+ client().volume_actions.post_by_id(volume['id'], attach_request)
216
+ except do.exceptions().HttpResponseError as err:
217
+ raise DigitalOceanError(
218
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
219
+ ) from err
220
+ logger.debug(f'{instance_name} created')
221
+ return instance
222
+
223
+
224
+ def start_instance(instance: Dict[str, Any]):
225
+ try:
226
+ client().droplet_actions.post(droplet_id=instance['id'],
227
+ body={'type': 'power_on'})
228
+ except do.exceptions().HttpResponseError as err:
229
+ raise DigitalOceanError(
230
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
231
+ ) from err
232
+
233
+
234
+ def stop_instance(instance: Dict[str, Any]):
235
+ try:
236
+ client().droplet_actions.post(
237
+ droplet_id=instance['id'],
238
+ body={'type': 'shutdown'},
239
+ )
240
+ except do.exceptions().HttpResponseError as err:
241
+ raise DigitalOceanError(
242
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
243
+ ) from err
244
+
245
+
246
+ def down_instance(instance: Dict[str, Any]):
247
+ # We use dangerous destroy to atomically delete
248
+ # block storage and instance for autodown
249
+ try:
250
+ client().droplets.destroy_with_associated_resources_dangerous(
251
+ droplet_id=instance['id'], x_dangerous=True)
252
+ except do.exceptions().HttpResponseError as err:
253
+ if 'a destroy is already in progress' in err.error.message:
254
+ return
255
+ raise DigitalOceanError(
256
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
257
+ ) from err
258
+
259
+
260
+ def rename_instance(instance: Dict[str, Any], new_name: str):
261
+ try:
262
+ client().droplet_actions.rename(droplet=instance['id'],
263
+ body={
264
+ 'type': 'rename',
265
+ 'name': new_name
266
+ })
267
+ except do.exceptions().HttpResponseError as err:
268
+ raise DigitalOceanError(
269
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
270
+ ) from err
271
+
272
+
273
+ def filter_instances(
274
+ cluster_name_on_cloud: str,
275
+ status_filters: Optional[List[str]] = None) -> Dict[str, Any]:
276
+ """Returns Dict mapping instance name
277
+ to instance metadata filtered by status
278
+ """
279
+
280
+ filtered_instances: Dict[str, Any] = {}
281
+ page = 1
282
+ paginated = True
283
+ while paginated:
284
+ try:
285
+ resp = client().droplets.list(
286
+ tag_name=f'{provision_constants.TAG_SKYPILOT_CLUSTER_NAME}:'
287
+ f'{cluster_name_on_cloud}',
288
+ per_page=50,
289
+ page=page)
290
+ for instance in resp['droplets']:
291
+ if status_filters is None or instance[
292
+ 'status'] in status_filters:
293
+ filtered_instances[instance['name']] = instance
294
+ except do.exceptions().HttpResponseError as err:
295
+ raise DigitalOceanError(
296
+ f'Error: {err.status_code} {err.reason}: {err.error.message}'
297
+ ) from err
298
+
299
+ pages = resp['links']
300
+ if 'pages' in pages and 'next' in pages['pages']:
301
+ pages = pages['pages']
302
+ parsed_url = urllib.parse.urlparse(pages['next'])
303
+ page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
304
+ else:
305
+ paginated = False
306
+ return filtered_instances
@@ -338,14 +338,20 @@ class DockerInitializer:
338
338
  no_exist = 'NoExist'
339
339
  # SkyPilot: Add the current user to the docker group first (if needed),
340
340
  # before checking if docker is installed to avoid permission issues.
341
- cleaned_output = self._run(
342
- 'id -nG $USER | grep -qw docker || '
343
- 'sudo usermod -aG docker $USER > /dev/null 2>&1;'
344
- f'command -v {self.docker_cmd} || echo {no_exist!r}')
345
- if no_exist in cleaned_output or 'docker' not in cleaned_output:
346
- logger.error(
347
- f'{self.docker_cmd.capitalize()} not installed. Please use an '
348
- f'image with {self.docker_cmd.capitalize()} installed.')
341
+ docker_cmd = ('id -nG $USER | grep -qw docker || '
342
+ 'sudo usermod -aG docker $USER > /dev/null 2>&1;'
343
+ f'command -v {self.docker_cmd} || echo {no_exist!r}')
344
+ cleaned_output = self._run(docker_cmd)
345
+ timeout = 60 * 10 # 10 minute timeout
346
+ start = time.time()
347
+ while no_exist in cleaned_output or 'docker' not in cleaned_output:
348
+ if time.time() - start > timeout:
349
+ logger.error(
350
+ f'{self.docker_cmd.capitalize()} not installed. Please use '
351
+ f'an image with {self.docker_cmd.capitalize()} installed.')
352
+ return
353
+ time.sleep(5)
354
+ cleaned_output = self._run(docker_cmd)
349
355
 
350
356
  def _check_container_status(self):
351
357
  if self.initialized:
@@ -415,7 +415,6 @@ def _post_provision_setup(
415
415
  f'{json.dumps(dataclasses.asdict(provision_record), indent=2)}\n'
416
416
  'Cluster info:\n'
417
417
  f'{json.dumps(dataclasses.asdict(cluster_info), indent=2)}')
418
-
419
418
  head_instance = cluster_info.get_head_instance()
420
419
  if head_instance is None:
421
420
  e = RuntimeError(f'Provision failed for cluster {cluster_name!r}. '
@@ -127,6 +127,7 @@ extras_require: Dict[str, List[str]] = {
127
127
  'fluidstack': [], # No dependencies needed for fluidstack
128
128
  'cudo': ['cudo-compute>=0.1.10'],
129
129
  'paperspace': [], # No dependencies needed for paperspace
130
+ 'do': ['pydo>=0.3.0', 'azure-core>=1.24.0', 'azure-common'],
130
131
  'vsphere': [
131
132
  'pyvmomi==8.0.1.0.2',
132
133
  # vsphere-automation-sdk is also required, but it does not have