skypilot-nightly 1.0.0.dev20241109__py3-none-any.whl → 1.0.0.dev20241111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/backends/cloud_vm_ray_backend.py +0 -19
- sky/clouds/oci.py +11 -21
- sky/clouds/service_catalog/oci_catalog.py +1 -1
- sky/clouds/utils/oci_utils.py +16 -2
- sky/dag.py +19 -15
- sky/provision/__init__.py +1 -0
- sky/provision/docker_utils.py +1 -1
- sky/provision/kubernetes/instance.py +104 -102
- sky/provision/oci/__init__.py +15 -0
- sky/provision/oci/config.py +51 -0
- sky/provision/oci/instance.py +430 -0
- sky/{skylet/providers/oci/query_helper.py → provision/oci/query_utils.py} +148 -59
- sky/serve/__init__.py +2 -0
- sky/serve/load_balancer.py +34 -8
- sky/serve/load_balancing_policies.py +23 -1
- sky/serve/service.py +4 -1
- sky/serve/service_spec.py +19 -0
- sky/setup_files/MANIFEST.in +0 -1
- sky/skylet/job_lib.py +29 -17
- sky/templates/kubernetes-ray.yml.j2 +21 -1
- sky/templates/oci-ray.yml.j2 +3 -53
- sky/utils/schemas.py +8 -0
- {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/METADATA +1 -1
- {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/RECORD +29 -29
- sky/skylet/providers/oci/__init__.py +0 -2
- sky/skylet/providers/oci/node_provider.py +0 -488
- sky/skylet/providers/oci/utils.py +0 -21
- {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/top_level.txt +0 -0
@@ -1,56 +1,75 @@
|
|
1
|
-
"""
|
2
|
-
Helper class for some OCI operations methods which needs to be shared/called
|
3
|
-
by multiple places.
|
1
|
+
"""OCI query helper class
|
4
2
|
|
5
3
|
History:
|
6
|
-
- Hysun He (hysun.he@oracle.com) @
|
7
|
-
|
4
|
+
- Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Code here mainly
|
5
|
+
migrated from the old provisioning API.
|
6
|
+
- Hysun He (hysun.he@oracle.com) @ Oct.18, 2024: Enhancement.
|
7
|
+
find_compartment: allow search subtree when find a compartment.
|
8
8
|
"""
|
9
|
-
|
10
9
|
from datetime import datetime
|
11
|
-
import
|
10
|
+
import functools
|
11
|
+
from logging import Logger
|
12
12
|
import re
|
13
13
|
import time
|
14
14
|
import traceback
|
15
15
|
import typing
|
16
16
|
from typing import Optional
|
17
17
|
|
18
|
+
from sky import sky_logging
|
18
19
|
from sky.adaptors import common as adaptors_common
|
19
20
|
from sky.adaptors import oci as oci_adaptor
|
20
21
|
from sky.clouds.utils import oci_utils
|
21
|
-
from sky.skylet.providers.oci import utils
|
22
22
|
|
23
23
|
if typing.TYPE_CHECKING:
|
24
24
|
import pandas as pd
|
25
25
|
else:
|
26
26
|
pd = adaptors_common.LazyImport('pandas')
|
27
27
|
|
28
|
-
logger =
|
28
|
+
logger = sky_logging.init_logger(__name__)
|
29
|
+
|
30
|
+
|
31
|
+
def debug_enabled(log: Logger):
|
32
|
+
|
33
|
+
def decorate(f):
|
34
|
+
|
35
|
+
@functools.wraps(f)
|
36
|
+
def wrapper(*args, **kwargs):
|
37
|
+
dt_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
38
|
+
log.debug(f'{dt_str} Enter {f}, {args}, {kwargs}')
|
39
|
+
try:
|
40
|
+
return f(*args, **kwargs)
|
41
|
+
finally:
|
42
|
+
log.debug(f'{dt_str} Exit {f}')
|
43
|
+
|
44
|
+
return wrapper
|
29
45
|
|
46
|
+
return decorate
|
30
47
|
|
31
|
-
class oci_query_helper:
|
32
48
|
|
49
|
+
class QueryHelper:
|
50
|
+
"""Helper class for some OCI operations
|
51
|
+
"""
|
33
52
|
# Call Cloud API to try getting the satisfied nodes.
|
34
53
|
@classmethod
|
35
|
-
@
|
54
|
+
@debug_enabled(logger)
|
36
55
|
def query_instances_by_tags(cls, tag_filters, region):
|
37
56
|
|
38
|
-
where_clause_tags =
|
57
|
+
where_clause_tags = ''
|
39
58
|
for tag_key in tag_filters:
|
40
|
-
if where_clause_tags !=
|
41
|
-
where_clause_tags +=
|
59
|
+
if where_clause_tags != '':
|
60
|
+
where_clause_tags += ' && '
|
42
61
|
|
43
62
|
tag_value = tag_filters[tag_key]
|
44
|
-
where_clause_tags += (f
|
45
|
-
f
|
63
|
+
where_clause_tags += (f'(freeformTags.key = \'{tag_key}\''
|
64
|
+
f' && freeformTags.value = \'{tag_value}\')')
|
46
65
|
|
47
|
-
qv_str = (f
|
48
|
-
f
|
49
|
-
f
|
66
|
+
qv_str = (f'query instance resources where {where_clause_tags}'
|
67
|
+
f' && (lifecycleState != \'TERMINATED\''
|
68
|
+
f' && lifecycleState != \'TERMINATING\')')
|
50
69
|
|
51
70
|
qv = oci_adaptor.oci.resource_search.models.StructuredSearchDetails(
|
52
71
|
query=qv_str,
|
53
|
-
type=
|
72
|
+
type='Structured',
|
54
73
|
matching_context_type=oci_adaptor.oci.resource_search.models.
|
55
74
|
SearchDetails.MATCHING_CONTEXT_TYPE_NONE,
|
56
75
|
)
|
@@ -63,44 +82,98 @@ class oci_query_helper:
|
|
63
82
|
|
64
83
|
@classmethod
|
65
84
|
def terminate_instances_by_tags(cls, tag_filters, region) -> int:
|
66
|
-
logger.debug(f
|
85
|
+
logger.debug(f'Terminate instance by tags: {tag_filters}')
|
67
86
|
insts = cls.query_instances_by_tags(tag_filters, region)
|
68
87
|
fail_count = 0
|
69
88
|
for inst in insts:
|
70
89
|
inst_id = inst.identifier
|
71
|
-
logger.debug(f
|
90
|
+
logger.debug(f'Got instance(to be terminated): {inst_id}')
|
72
91
|
|
73
92
|
try:
|
74
93
|
oci_adaptor.get_core_client(
|
75
94
|
region,
|
76
95
|
oci_utils.oci_config.get_profile()).terminate_instance(
|
77
96
|
inst_id)
|
78
|
-
except
|
97
|
+
except oci_adaptor.oci.exceptions.ServiceError as e:
|
79
98
|
fail_count += 1
|
80
|
-
logger.error(f
|
99
|
+
logger.error(f'Terminate instance failed: {str(e)}\n: {inst}')
|
81
100
|
traceback.print_exc()
|
82
101
|
|
83
102
|
if fail_count == 0:
|
84
|
-
logger.debug(
|
103
|
+
logger.debug('Instance teardown result: OK')
|
85
104
|
else:
|
86
|
-
logger.
|
105
|
+
logger.warning(f'Instance teardown result: {fail_count} failed!')
|
87
106
|
|
88
107
|
return fail_count
|
89
108
|
|
90
109
|
@classmethod
|
91
|
-
@
|
110
|
+
@debug_enabled(logger)
|
111
|
+
def launch_instance(cls, region, launch_config):
|
112
|
+
""" To create a new instance """
|
113
|
+
return oci_adaptor.get_core_client(
|
114
|
+
region, oci_utils.oci_config.get_profile()).launch_instance(
|
115
|
+
launch_instance_details=launch_config)
|
116
|
+
|
117
|
+
@classmethod
|
118
|
+
@debug_enabled(logger)
|
119
|
+
def start_instance(cls, region, instance_id):
|
120
|
+
""" To start an existing instance """
|
121
|
+
return oci_adaptor.get_core_client(
|
122
|
+
region, oci_utils.oci_config.get_profile()).instance_action(
|
123
|
+
instance_id=instance_id, action='START')
|
124
|
+
|
125
|
+
@classmethod
|
126
|
+
@debug_enabled(logger)
|
127
|
+
def stop_instance(cls, region, instance_id):
|
128
|
+
""" To stop an instance """
|
129
|
+
return oci_adaptor.get_core_client(
|
130
|
+
region, oci_utils.oci_config.get_profile()).instance_action(
|
131
|
+
instance_id=instance_id, action='STOP')
|
132
|
+
|
133
|
+
@classmethod
|
134
|
+
@debug_enabled(logger)
|
135
|
+
def wait_instance_until_status(cls, region, node_id, status):
|
136
|
+
""" To wait a instance becoming the specified state """
|
137
|
+
compute_client = oci_adaptor.get_core_client(
|
138
|
+
region, oci_utils.oci_config.get_profile())
|
139
|
+
|
140
|
+
resp = compute_client.get_instance(instance_id=node_id)
|
141
|
+
|
142
|
+
oci_adaptor.oci.wait_until(
|
143
|
+
compute_client,
|
144
|
+
resp,
|
145
|
+
'lifecycle_state',
|
146
|
+
status,
|
147
|
+
)
|
148
|
+
|
149
|
+
@classmethod
|
150
|
+
def get_instance_primary_vnic(cls, region, inst_info):
|
151
|
+
""" Get the primary vnic infomation of the instance """
|
152
|
+
list_vnic_attachments_response = oci_adaptor.get_core_client(
|
153
|
+
region, oci_utils.oci_config.get_profile()).list_vnic_attachments(
|
154
|
+
availability_domain=inst_info['ad'],
|
155
|
+
compartment_id=inst_info['compartment'],
|
156
|
+
instance_id=inst_info['inst_id'],
|
157
|
+
)
|
158
|
+
vnic = list_vnic_attachments_response.data[0]
|
159
|
+
return oci_adaptor.get_net_client(
|
160
|
+
region, oci_utils.oci_config.get_profile()).get_vnic(
|
161
|
+
vnic_id=vnic.vnic_id).data
|
162
|
+
|
163
|
+
@classmethod
|
164
|
+
@debug_enabled(logger)
|
92
165
|
def subscribe_image(cls, compartment_id, listing_id, resource_version,
|
93
166
|
region):
|
94
|
-
if (pd.isna(listing_id) or listing_id.strip() ==
|
95
|
-
listing_id.strip() ==
|
167
|
+
if (pd.isna(listing_id) or listing_id.strip() == 'None' or
|
168
|
+
listing_id.strip() == 'nan'):
|
96
169
|
return
|
97
170
|
|
98
171
|
core_client = oci_adaptor.get_core_client(
|
99
172
|
region, oci_utils.oci_config.get_profile())
|
100
173
|
try:
|
101
|
-
|
174
|
+
agreements_resp = core_client.get_app_catalog_listing_agreements(
|
102
175
|
listing_id=listing_id, resource_version=resource_version)
|
103
|
-
agreements =
|
176
|
+
agreements = agreements_resp.data
|
104
177
|
|
105
178
|
core_client.create_app_catalog_subscription(
|
106
179
|
create_app_catalog_subscription_details=oci_adaptor.oci.core.
|
@@ -113,24 +186,24 @@ class oci_query_helper:
|
|
113
186
|
oracle_terms_of_use_link,
|
114
187
|
time_retrieved=datetime.strptime(
|
115
188
|
re.sub(
|
116
|
-
|
117
|
-
|
189
|
+
r'\d{3}\+\d{2}\:\d{2}',
|
190
|
+
'Z',
|
118
191
|
str(agreements.time_retrieved),
|
119
192
|
0,
|
120
193
|
),
|
121
|
-
|
194
|
+
'%Y-%m-%d %H:%M:%S.%fZ',
|
122
195
|
),
|
123
196
|
signature=agreements.signature,
|
124
197
|
eula_link=agreements.eula_link,
|
125
198
|
))
|
126
|
-
except
|
199
|
+
except oci_adaptor.oci.exceptions.ServiceError as e:
|
127
200
|
logger.critical(
|
128
|
-
f
|
129
|
-
f
|
130
|
-
raise RuntimeError(
|
201
|
+
f'[Failed] subscribe_image: {listing_id} - {resource_version}'
|
202
|
+
f'Error message: {str(e)}')
|
203
|
+
raise RuntimeError('ERR: Image subscription error!') from e
|
131
204
|
|
132
205
|
@classmethod
|
133
|
-
@
|
206
|
+
@debug_enabled(logger)
|
134
207
|
def find_compartment(cls, region) -> str:
|
135
208
|
""" If compartment is not configured, we use root compartment """
|
136
209
|
# Try to use the configured one first
|
@@ -143,12 +216,18 @@ class oci_query_helper:
|
|
143
216
|
# config file is supported (2023/06/09).
|
144
217
|
root = oci_adaptor.get_oci_config(
|
145
218
|
region, oci_utils.oci_config.get_profile())['tenancy']
|
219
|
+
|
146
220
|
list_compartments_response = oci_adaptor.get_identity_client(
|
147
221
|
region, oci_utils.oci_config.get_profile()).list_compartments(
|
148
222
|
compartment_id=root,
|
149
223
|
name=oci_utils.oci_config.COMPARTMENT,
|
224
|
+
compartment_id_in_subtree=True,
|
225
|
+
access_level='ACCESSIBLE',
|
150
226
|
lifecycle_state='ACTIVE',
|
227
|
+
sort_by='TIMECREATED',
|
228
|
+
sort_order='DESC',
|
151
229
|
limit=1)
|
230
|
+
|
152
231
|
compartments = list_compartments_response.data
|
153
232
|
if len(compartments) > 0:
|
154
233
|
skypilot_compartment = compartments[0].id
|
@@ -159,7 +238,7 @@ class oci_query_helper:
|
|
159
238
|
return skypilot_compartment
|
160
239
|
|
161
240
|
@classmethod
|
162
|
-
@
|
241
|
+
@debug_enabled(logger)
|
163
242
|
def find_create_vcn_subnet(cls, region) -> Optional[str]:
|
164
243
|
""" If sub is not configured, we find/create VCN skypilot_vcn """
|
165
244
|
subnet = oci_utils.oci_config.get_vcn_subnet(region)
|
@@ -174,7 +253,7 @@ class oci_query_helper:
|
|
174
253
|
list_vcns_response = net_client.list_vcns(
|
175
254
|
compartment_id=skypilot_compartment,
|
176
255
|
display_name=oci_utils.oci_config.VCN_NAME,
|
177
|
-
lifecycle_state=
|
256
|
+
lifecycle_state='AVAILABLE')
|
178
257
|
vcns = list_vcns_response.data
|
179
258
|
if len(vcns) > 0:
|
180
259
|
# Found the VCN.
|
@@ -184,7 +263,7 @@ class oci_query_helper:
|
|
184
263
|
limit=1,
|
185
264
|
vcn_id=skypilot_vcn,
|
186
265
|
display_name=oci_utils.oci_config.VCN_SUBNET_NAME,
|
187
|
-
lifecycle_state=
|
266
|
+
lifecycle_state='AVAILABLE')
|
188
267
|
logger.debug(f'Got VCN subnet \n{list_subnets_response.data}')
|
189
268
|
if len(list_subnets_response.data) < 1:
|
190
269
|
logger.error(
|
@@ -201,10 +280,17 @@ class oci_query_helper:
|
|
201
280
|
return cls.create_vcn_subnet(net_client, skypilot_compartment)
|
202
281
|
|
203
282
|
@classmethod
|
204
|
-
@
|
283
|
+
@debug_enabled(logger)
|
205
284
|
def create_vcn_subnet(cls, net_client,
|
206
285
|
skypilot_compartment) -> Optional[str]:
|
286
|
+
|
287
|
+
skypilot_vcn = None # VCN for the resources
|
288
|
+
subnet = None # Subnet for the VMs
|
289
|
+
ig = None # Internet gateway
|
290
|
+
sg = None # Service gateway
|
291
|
+
|
207
292
|
try:
|
293
|
+
# pylint: disable=line-too-long
|
208
294
|
create_vcn_response = net_client.create_vcn(
|
209
295
|
create_vcn_details=oci_adaptor.oci.core.models.CreateVcnDetails(
|
210
296
|
compartment_id=skypilot_compartment,
|
@@ -274,38 +360,38 @@ class oci_query_helper:
|
|
274
360
|
update_security_list_details=oci_adaptor.oci.core.models.
|
275
361
|
UpdateSecurityListDetails(ingress_security_rules=[
|
276
362
|
oci_adaptor.oci.core.models.IngressSecurityRule(
|
277
|
-
protocol=
|
363
|
+
protocol='6',
|
278
364
|
source=oci_utils.oci_config.VCN_CIDR_INTERNET,
|
279
365
|
is_stateless=False,
|
280
|
-
source_type=
|
366
|
+
source_type='CIDR_BLOCK',
|
281
367
|
tcp_options=oci_adaptor.oci.core.models.TcpOptions(
|
282
368
|
destination_port_range=oci_adaptor.oci.core.models.
|
283
369
|
PortRange(max=22, min=22),
|
284
370
|
source_port_range=oci_adaptor.oci.core.models.
|
285
371
|
PortRange(max=65535, min=1)),
|
286
|
-
description=
|
372
|
+
description='Allow SSH port.'),
|
287
373
|
oci_adaptor.oci.core.models.IngressSecurityRule(
|
288
|
-
protocol=
|
374
|
+
protocol='all',
|
289
375
|
source=oci_utils.oci_config.VCN_SUBNET_CIDR,
|
290
376
|
is_stateless=False,
|
291
|
-
source_type=
|
292
|
-
description=
|
377
|
+
source_type='CIDR_BLOCK',
|
378
|
+
description='Allow all traffic from/to same subnet.'),
|
293
379
|
oci_adaptor.oci.core.models.IngressSecurityRule(
|
294
|
-
protocol=
|
380
|
+
protocol='1',
|
295
381
|
source=oci_utils.oci_config.VCN_CIDR_INTERNET,
|
296
382
|
is_stateless=False,
|
297
|
-
source_type=
|
383
|
+
source_type='CIDR_BLOCK',
|
298
384
|
icmp_options=oci_adaptor.oci.core.models.IcmpOptions(
|
299
385
|
type=3, code=4),
|
300
|
-
description=
|
386
|
+
description='ICMP traffic.'),
|
301
387
|
oci_adaptor.oci.core.models.IngressSecurityRule(
|
302
|
-
protocol=
|
388
|
+
protocol='1',
|
303
389
|
source=oci_utils.oci_config.VCN_CIDR,
|
304
390
|
is_stateless=False,
|
305
|
-
source_type=
|
391
|
+
source_type='CIDR_BLOCK',
|
306
392
|
icmp_options=oci_adaptor.oci.core.models.IcmpOptions(
|
307
393
|
type=3),
|
308
|
-
description=
|
394
|
+
description='ICMP traffic (VCN).'),
|
309
395
|
]))
|
310
396
|
logger.debug(
|
311
397
|
f'Updated security_list: \n{update_security_list_response.data}'
|
@@ -325,7 +411,7 @@ class oci_query_helper:
|
|
325
411
|
]))
|
326
412
|
logger.debug(f'Route table: \n{update_route_table_response.data}')
|
327
413
|
|
328
|
-
except oci_adaptor.
|
414
|
+
except oci_adaptor.oci.exceptions.ServiceError as e:
|
329
415
|
logger.error(f'Create VCN Error: Create new VCN '
|
330
416
|
f'{oci_utils.oci_config.VCN_NAME} failed: {str(e)}')
|
331
417
|
# In case of partial success while creating vcn
|
@@ -335,7 +421,7 @@ class oci_query_helper:
|
|
335
421
|
return subnet
|
336
422
|
|
337
423
|
@classmethod
|
338
|
-
@
|
424
|
+
@debug_enabled(logger)
|
339
425
|
def delete_vcn(cls, net_client, skypilot_vcn, skypilot_subnet,
|
340
426
|
internet_gateway, service_gateway):
|
341
427
|
if skypilot_vcn is None:
|
@@ -369,7 +455,7 @@ class oci_query_helper:
|
|
369
455
|
f'Deleted vcn {skypilot_vcn}-{delete_vcn_response.data}'
|
370
456
|
)
|
371
457
|
break
|
372
|
-
except oci_adaptor.
|
458
|
+
except oci_adaptor.oci.exceptions.ServiceError as e:
|
373
459
|
logger.info(f'Waiting del SG/IG/Subnet finish: {str(e)}')
|
374
460
|
retry_count = retry_count + 1
|
375
461
|
if retry_count == oci_utils.oci_config.MAX_RETRY_COUNT:
|
@@ -378,6 +464,9 @@ class oci_query_helper:
|
|
378
464
|
time.sleep(
|
379
465
|
oci_utils.oci_config.RETRY_INTERVAL_BASE_SECONDS)
|
380
466
|
|
381
|
-
except oci_adaptor.
|
467
|
+
except oci_adaptor.oci.exceptions.ServiceError as e:
|
382
468
|
logger.error(
|
383
469
|
f'Delete VCN {oci_utils.oci_config.VCN_NAME} Error: {str(e)}')
|
470
|
+
|
471
|
+
|
472
|
+
query_helper = QueryHelper()
|
sky/serve/__init__.py
CHANGED
@@ -11,6 +11,7 @@ from sky.serve.core import tail_logs
|
|
11
11
|
from sky.serve.core import terminate_replica
|
12
12
|
from sky.serve.core import up
|
13
13
|
from sky.serve.core import update
|
14
|
+
from sky.serve.load_balancing_policies import LB_POLICIES
|
14
15
|
from sky.serve.serve_state import ReplicaStatus
|
15
16
|
from sky.serve.serve_state import ServiceStatus
|
16
17
|
from sky.serve.serve_utils import DEFAULT_UPDATE_MODE
|
@@ -35,6 +36,7 @@ __all__ = [
|
|
35
36
|
'get_endpoint',
|
36
37
|
'INITIAL_VERSION',
|
37
38
|
'LB_CONTROLLER_SYNC_INTERVAL_SECONDS',
|
39
|
+
'LB_POLICIES',
|
38
40
|
'ReplicaStatus',
|
39
41
|
'ServiceComponent',
|
40
42
|
'ServiceStatus',
|
sky/serve/load_balancer.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
import asyncio
|
3
3
|
import logging
|
4
4
|
import threading
|
5
|
-
from typing import Dict, Union
|
5
|
+
from typing import Dict, Optional, Union
|
6
6
|
|
7
7
|
import aiohttp
|
8
8
|
import fastapi
|
@@ -27,18 +27,24 @@ class SkyServeLoadBalancer:
|
|
27
27
|
policy.
|
28
28
|
"""
|
29
29
|
|
30
|
-
def __init__(self,
|
30
|
+
def __init__(self,
|
31
|
+
controller_url: str,
|
32
|
+
load_balancer_port: int,
|
33
|
+
load_balancing_policy_name: Optional[str] = None) -> None:
|
31
34
|
"""Initialize the load balancer.
|
32
35
|
|
33
36
|
Args:
|
34
37
|
controller_url: The URL of the controller.
|
35
38
|
load_balancer_port: The port where the load balancer listens to.
|
39
|
+
load_balancing_policy_name: The name of the load balancing policy
|
40
|
+
to use. Defaults to None.
|
36
41
|
"""
|
37
42
|
self._app = fastapi.FastAPI()
|
38
43
|
self._controller_url: str = controller_url
|
39
44
|
self._load_balancer_port: int = load_balancer_port
|
40
|
-
|
41
|
-
|
45
|
+
# Use the registry to create the load balancing policy
|
46
|
+
self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make(
|
47
|
+
load_balancing_policy_name)
|
42
48
|
self._request_aggregator: serve_utils.RequestsAggregator = (
|
43
49
|
serve_utils.RequestTimestamp())
|
44
50
|
# TODO(tian): httpx.Client has a resource limit of 100 max connections
|
@@ -223,9 +229,21 @@ class SkyServeLoadBalancer:
|
|
223
229
|
uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port)
|
224
230
|
|
225
231
|
|
226
|
-
def run_load_balancer(controller_addr: str,
|
227
|
-
|
228
|
-
|
232
|
+
def run_load_balancer(controller_addr: str,
|
233
|
+
load_balancer_port: int,
|
234
|
+
load_balancing_policy_name: Optional[str] = None) -> None:
|
235
|
+
""" Run the load balancer.
|
236
|
+
|
237
|
+
Args:
|
238
|
+
controller_addr: The address of the controller.
|
239
|
+
load_balancer_port: The port where the load balancer listens to.
|
240
|
+
policy_name: The name of the load balancing policy to use. Defaults to
|
241
|
+
None.
|
242
|
+
"""
|
243
|
+
load_balancer = SkyServeLoadBalancer(
|
244
|
+
controller_url=controller_addr,
|
245
|
+
load_balancer_port=load_balancer_port,
|
246
|
+
load_balancing_policy_name=load_balancing_policy_name)
|
229
247
|
load_balancer.run()
|
230
248
|
|
231
249
|
|
@@ -241,5 +259,13 @@ if __name__ == '__main__':
|
|
241
259
|
required=True,
|
242
260
|
default=8890,
|
243
261
|
help='The port where the load balancer listens to.')
|
262
|
+
available_policies = list(lb_policies.LB_POLICIES.keys())
|
263
|
+
parser.add_argument(
|
264
|
+
'--load-balancing-policy',
|
265
|
+
choices=available_policies,
|
266
|
+
default='round_robin',
|
267
|
+
help=f'The load balancing policy to use. Available policies: '
|
268
|
+
f'{", ".join(available_policies)}.')
|
244
269
|
args = parser.parse_args()
|
245
|
-
run_load_balancer(args.controller_addr, args.load_balancer_port
|
270
|
+
run_load_balancer(args.controller_addr, args.load_balancer_port,
|
271
|
+
args.load_balancing_policy)
|
@@ -10,6 +10,10 @@ if typing.TYPE_CHECKING:
|
|
10
10
|
|
11
11
|
logger = sky_logging.init_logger(__name__)
|
12
12
|
|
13
|
+
# Define a registry for load balancing policies
|
14
|
+
LB_POLICIES = {}
|
15
|
+
DEFAULT_LB_POLICY = None
|
16
|
+
|
13
17
|
|
14
18
|
def _request_repr(request: 'fastapi.Request') -> str:
|
15
19
|
return ('<Request '
|
@@ -25,6 +29,24 @@ class LoadBalancingPolicy:
|
|
25
29
|
def __init__(self) -> None:
|
26
30
|
self.ready_replicas: List[str] = []
|
27
31
|
|
32
|
+
def __init_subclass__(cls, name: str, default: bool = False):
|
33
|
+
LB_POLICIES[name] = cls
|
34
|
+
if default:
|
35
|
+
global DEFAULT_LB_POLICY
|
36
|
+
assert DEFAULT_LB_POLICY is None, (
|
37
|
+
'Only one policy can be default.')
|
38
|
+
DEFAULT_LB_POLICY = name
|
39
|
+
|
40
|
+
@classmethod
|
41
|
+
def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy':
|
42
|
+
"""Create a load balancing policy from a name."""
|
43
|
+
if policy_name is None:
|
44
|
+
policy_name = DEFAULT_LB_POLICY
|
45
|
+
|
46
|
+
if policy_name not in LB_POLICIES:
|
47
|
+
raise ValueError(f'Unknown load balancing policy: {policy_name}')
|
48
|
+
return LB_POLICIES[policy_name]()
|
49
|
+
|
28
50
|
def set_ready_replicas(self, ready_replicas: List[str]) -> None:
|
29
51
|
raise NotImplementedError
|
30
52
|
|
@@ -44,7 +66,7 @@ class LoadBalancingPolicy:
|
|
44
66
|
raise NotImplementedError
|
45
67
|
|
46
68
|
|
47
|
-
class RoundRobinPolicy(LoadBalancingPolicy):
|
69
|
+
class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True):
|
48
70
|
"""Round-robin load balancing policy."""
|
49
71
|
|
50
72
|
def __init__(self) -> None:
|
sky/serve/service.py
CHANGED
@@ -219,6 +219,9 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
|
|
219
219
|
load_balancer_port = common_utils.find_free_port(
|
220
220
|
constants.LOAD_BALANCER_PORT_START)
|
221
221
|
|
222
|
+
# Extract the load balancing policy from the service spec
|
223
|
+
policy_name = service_spec.load_balancing_policy
|
224
|
+
|
222
225
|
# Start the load balancer.
|
223
226
|
# TODO(tian): Probably we could enable multiple ports specified in
|
224
227
|
# service spec and we could start multiple load balancers.
|
@@ -227,7 +230,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
|
|
227
230
|
target=ux_utils.RedirectOutputForProcess(
|
228
231
|
load_balancer.run_load_balancer,
|
229
232
|
load_balancer_log_file).run,
|
230
|
-
args=(controller_addr, load_balancer_port))
|
233
|
+
args=(controller_addr, load_balancer_port, policy_name))
|
231
234
|
load_balancer_process.start()
|
232
235
|
serve_state.set_service_load_balancer_port(service_name,
|
233
236
|
load_balancer_port)
|
sky/serve/service_spec.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional
|
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
9
|
+
from sky import serve
|
9
10
|
from sky.serve import constants
|
10
11
|
from sky.utils import common_utils
|
11
12
|
from sky.utils import schemas
|
@@ -29,6 +30,7 @@ class SkyServiceSpec:
|
|
29
30
|
base_ondemand_fallback_replicas: Optional[int] = None,
|
30
31
|
upscale_delay_seconds: Optional[int] = None,
|
31
32
|
downscale_delay_seconds: Optional[int] = None,
|
33
|
+
load_balancing_policy: Optional[str] = None,
|
32
34
|
) -> None:
|
33
35
|
if max_replicas is not None and max_replicas < min_replicas:
|
34
36
|
with ux_utils.print_exception_no_traceback():
|
@@ -55,6 +57,13 @@ class SkyServiceSpec:
|
|
55
57
|
raise ValueError('readiness_path must start with a slash (/). '
|
56
58
|
f'Got: {readiness_path}')
|
57
59
|
|
60
|
+
# Add the check for unknown load balancing policies
|
61
|
+
if (load_balancing_policy is not None and
|
62
|
+
load_balancing_policy not in serve.LB_POLICIES):
|
63
|
+
with ux_utils.print_exception_no_traceback():
|
64
|
+
raise ValueError(
|
65
|
+
f'Unknown load balancing policy: {load_balancing_policy}. '
|
66
|
+
f'Available policies: {list(serve.LB_POLICIES.keys())}')
|
58
67
|
self._readiness_path: str = readiness_path
|
59
68
|
self._initial_delay_seconds: int = initial_delay_seconds
|
60
69
|
self._readiness_timeout_seconds: int = readiness_timeout_seconds
|
@@ -69,6 +78,7 @@ class SkyServiceSpec:
|
|
69
78
|
int] = base_ondemand_fallback_replicas
|
70
79
|
self._upscale_delay_seconds: Optional[int] = upscale_delay_seconds
|
71
80
|
self._downscale_delay_seconds: Optional[int] = downscale_delay_seconds
|
81
|
+
self._load_balancing_policy: Optional[str] = load_balancing_policy
|
72
82
|
|
73
83
|
self._use_ondemand_fallback: bool = (
|
74
84
|
self.dynamic_ondemand_fallback is not None and
|
@@ -150,6 +160,8 @@ class SkyServiceSpec:
|
|
150
160
|
service_config['dynamic_ondemand_fallback'] = policy_section.get(
|
151
161
|
'dynamic_ondemand_fallback', None)
|
152
162
|
|
163
|
+
service_config['load_balancing_policy'] = config.get(
|
164
|
+
'load_balancing_policy', None)
|
153
165
|
return SkyServiceSpec(**service_config)
|
154
166
|
|
155
167
|
@staticmethod
|
@@ -205,6 +217,8 @@ class SkyServiceSpec:
|
|
205
217
|
self.upscale_delay_seconds)
|
206
218
|
add_if_not_none('replica_policy', 'downscale_delay_seconds',
|
207
219
|
self.downscale_delay_seconds)
|
220
|
+
add_if_not_none('load_balancing_policy', None,
|
221
|
+
self._load_balancing_policy)
|
208
222
|
return config
|
209
223
|
|
210
224
|
def probe_str(self):
|
@@ -256,6 +270,7 @@ class SkyServiceSpec:
|
|
256
270
|
Readiness probe timeout seconds: {self.readiness_timeout_seconds}
|
257
271
|
Replica autoscaling policy: {self.autoscaling_policy_str()}
|
258
272
|
Spot Policy: {self.spot_policy_str()}
|
273
|
+
Load Balancing Policy: {self.load_balancing_policy}
|
259
274
|
""")
|
260
275
|
|
261
276
|
@property
|
@@ -310,3 +325,7 @@ class SkyServiceSpec:
|
|
310
325
|
@property
|
311
326
|
def use_ondemand_fallback(self) -> bool:
|
312
327
|
return self._use_ondemand_fallback
|
328
|
+
|
329
|
+
@property
|
330
|
+
def load_balancing_policy(self) -> Optional[str]:
|
331
|
+
return self._load_balancing_policy
|
sky/setup_files/MANIFEST.in
CHANGED
@@ -6,7 +6,6 @@ include sky/setup_files/*
|
|
6
6
|
include sky/skylet/*.sh
|
7
7
|
include sky/skylet/LICENSE
|
8
8
|
include sky/skylet/providers/ibm/*
|
9
|
-
include sky/skylet/providers/oci/*
|
10
9
|
include sky/skylet/providers/scp/*
|
11
10
|
include sky/skylet/providers/*.py
|
12
11
|
include sky/skylet/ray_patches/*.patch
|