skypilot-nightly 1.0.0.dev20241109__py3-none-any.whl → 1.0.0.dev20241111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. sky/__init__.py +2 -2
  2. sky/backends/cloud_vm_ray_backend.py +0 -19
  3. sky/clouds/oci.py +11 -21
  4. sky/clouds/service_catalog/oci_catalog.py +1 -1
  5. sky/clouds/utils/oci_utils.py +16 -2
  6. sky/dag.py +19 -15
  7. sky/provision/__init__.py +1 -0
  8. sky/provision/docker_utils.py +1 -1
  9. sky/provision/kubernetes/instance.py +104 -102
  10. sky/provision/oci/__init__.py +15 -0
  11. sky/provision/oci/config.py +51 -0
  12. sky/provision/oci/instance.py +430 -0
  13. sky/{skylet/providers/oci/query_helper.py → provision/oci/query_utils.py} +148 -59
  14. sky/serve/__init__.py +2 -0
  15. sky/serve/load_balancer.py +34 -8
  16. sky/serve/load_balancing_policies.py +23 -1
  17. sky/serve/service.py +4 -1
  18. sky/serve/service_spec.py +19 -0
  19. sky/setup_files/MANIFEST.in +0 -1
  20. sky/skylet/job_lib.py +29 -17
  21. sky/templates/kubernetes-ray.yml.j2 +21 -1
  22. sky/templates/oci-ray.yml.j2 +3 -53
  23. sky/utils/schemas.py +8 -0
  24. {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/METADATA +1 -1
  25. {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/RECORD +29 -29
  26. sky/skylet/providers/oci/__init__.py +0 -2
  27. sky/skylet/providers/oci/node_provider.py +0 -488
  28. sky/skylet/providers/oci/utils.py +0 -21
  29. {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/LICENSE +0 -0
  30. {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/WHEEL +0 -0
  31. {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/entry_points.txt +0 -0
  32. {skypilot_nightly-1.0.0.dev20241109.dist-info → skypilot_nightly-1.0.0.dev20241111.dist-info}/top_level.txt +0 -0
sky/{skylet/providers/oci/query_helper.py → provision/oci/query_utils.py} CHANGED
@@ -1,56 +1,75 @@
1
- """
2
- Helper class for some OCI operations methods which needs to be shared/called
3
- by multiple places.
1
+ """OCI query helper class
4
2
 
5
3
  History:
6
- - Hysun He (hysun.he@oracle.com) @ Apr, 2023: Initial implementation
7
-
4
+ - Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Code here mainly
5
+ migrated from the old provisioning API.
6
+ - Hysun He (hysun.he@oracle.com) @ Oct.18, 2024: Enhancement.
7
+ find_compartment: allow search subtree when find a compartment.
8
8
  """
9
-
10
9
  from datetime import datetime
11
- import logging
10
+ import functools
11
+ from logging import Logger
12
12
  import re
13
13
  import time
14
14
  import traceback
15
15
  import typing
16
16
  from typing import Optional
17
17
 
18
+ from sky import sky_logging
18
19
  from sky.adaptors import common as adaptors_common
19
20
  from sky.adaptors import oci as oci_adaptor
20
21
  from sky.clouds.utils import oci_utils
21
- from sky.skylet.providers.oci import utils
22
22
 
23
23
  if typing.TYPE_CHECKING:
24
24
  import pandas as pd
25
25
  else:
26
26
  pd = adaptors_common.LazyImport('pandas')
27
27
 
28
- logger = logging.getLogger(__name__)
28
+ logger = sky_logging.init_logger(__name__)
29
+
30
+
31
+ def debug_enabled(log: Logger):
32
+
33
+ def decorate(f):
34
+
35
+ @functools.wraps(f)
36
+ def wrapper(*args, **kwargs):
37
+ dt_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
38
+ log.debug(f'{dt_str} Enter {f}, {args}, {kwargs}')
39
+ try:
40
+ return f(*args, **kwargs)
41
+ finally:
42
+ log.debug(f'{dt_str} Exit {f}')
43
+
44
+ return wrapper
29
45
 
46
+ return decorate
30
47
 
31
- class oci_query_helper:
32
48
 
49
+ class QueryHelper:
50
+ """Helper class for some OCI operations
51
+ """
33
52
  # Call Cloud API to try getting the satisfied nodes.
34
53
  @classmethod
35
- @utils.debug_enabled(logger=logger)
54
+ @debug_enabled(logger)
36
55
  def query_instances_by_tags(cls, tag_filters, region):
37
56
 
38
- where_clause_tags = ""
57
+ where_clause_tags = ''
39
58
  for tag_key in tag_filters:
40
- if where_clause_tags != "":
41
- where_clause_tags += " && "
59
+ if where_clause_tags != '':
60
+ where_clause_tags += ' && '
42
61
 
43
62
  tag_value = tag_filters[tag_key]
44
- where_clause_tags += (f"(freeformTags.key = '{tag_key}'"
45
- f" && freeformTags.value = '{tag_value}')")
63
+ where_clause_tags += (f'(freeformTags.key = \'{tag_key}\''
64
+ f' && freeformTags.value = \'{tag_value}\')')
46
65
 
47
- qv_str = (f"query instance resources where {where_clause_tags}"
48
- f" && (lifecycleState != 'TERMINATED'"
49
- f" && lifecycleState != 'TERMINATING')")
66
+ qv_str = (f'query instance resources where {where_clause_tags}'
67
+ f' && (lifecycleState != \'TERMINATED\''
68
+ f' && lifecycleState != \'TERMINATING\')')
50
69
 
51
70
  qv = oci_adaptor.oci.resource_search.models.StructuredSearchDetails(
52
71
  query=qv_str,
53
- type="Structured",
72
+ type='Structured',
54
73
  matching_context_type=oci_adaptor.oci.resource_search.models.
55
74
  SearchDetails.MATCHING_CONTEXT_TYPE_NONE,
56
75
  )
@@ -63,44 +82,98 @@ class oci_query_helper:
63
82
 
64
83
  @classmethod
65
84
  def terminate_instances_by_tags(cls, tag_filters, region) -> int:
66
- logger.debug(f"Terminate instance by tags: {tag_filters}")
85
+ logger.debug(f'Terminate instance by tags: {tag_filters}')
67
86
  insts = cls.query_instances_by_tags(tag_filters, region)
68
87
  fail_count = 0
69
88
  for inst in insts:
70
89
  inst_id = inst.identifier
71
- logger.debug(f"Got instance(to be terminated): {inst_id}")
90
+ logger.debug(f'Got instance(to be terminated): {inst_id}')
72
91
 
73
92
  try:
74
93
  oci_adaptor.get_core_client(
75
94
  region,
76
95
  oci_utils.oci_config.get_profile()).terminate_instance(
77
96
  inst_id)
78
- except Exception as e:
97
+ except oci_adaptor.oci.exceptions.ServiceError as e:
79
98
  fail_count += 1
80
- logger.error(f"Terminate instance failed: {str(e)}\n: {inst}")
99
+ logger.error(f'Terminate instance failed: {str(e)}\n: {inst}')
81
100
  traceback.print_exc()
82
101
 
83
102
  if fail_count == 0:
84
- logger.debug(f"Instance teardown result: OK")
103
+ logger.debug('Instance teardown result: OK')
85
104
  else:
86
- logger.warn(f"Instance teardown result: {fail_count} failed!")
105
+ logger.warning(f'Instance teardown result: {fail_count} failed!')
87
106
 
88
107
  return fail_count
89
108
 
90
109
  @classmethod
91
- @utils.debug_enabled(logger=logger)
110
+ @debug_enabled(logger)
111
+ def launch_instance(cls, region, launch_config):
112
+ """ To create a new instance """
113
+ return oci_adaptor.get_core_client(
114
+ region, oci_utils.oci_config.get_profile()).launch_instance(
115
+ launch_instance_details=launch_config)
116
+
117
+ @classmethod
118
+ @debug_enabled(logger)
119
+ def start_instance(cls, region, instance_id):
120
+ """ To start an existing instance """
121
+ return oci_adaptor.get_core_client(
122
+ region, oci_utils.oci_config.get_profile()).instance_action(
123
+ instance_id=instance_id, action='START')
124
+
125
+ @classmethod
126
+ @debug_enabled(logger)
127
+ def stop_instance(cls, region, instance_id):
128
+ """ To stop an instance """
129
+ return oci_adaptor.get_core_client(
130
+ region, oci_utils.oci_config.get_profile()).instance_action(
131
+ instance_id=instance_id, action='STOP')
132
+
133
+ @classmethod
134
+ @debug_enabled(logger)
135
+ def wait_instance_until_status(cls, region, node_id, status):
136
+ """ To wait a instance becoming the specified state """
137
+ compute_client = oci_adaptor.get_core_client(
138
+ region, oci_utils.oci_config.get_profile())
139
+
140
+ resp = compute_client.get_instance(instance_id=node_id)
141
+
142
+ oci_adaptor.oci.wait_until(
143
+ compute_client,
144
+ resp,
145
+ 'lifecycle_state',
146
+ status,
147
+ )
148
+
149
+ @classmethod
150
+ def get_instance_primary_vnic(cls, region, inst_info):
151
+ """ Get the primary vnic infomation of the instance """
152
+ list_vnic_attachments_response = oci_adaptor.get_core_client(
153
+ region, oci_utils.oci_config.get_profile()).list_vnic_attachments(
154
+ availability_domain=inst_info['ad'],
155
+ compartment_id=inst_info['compartment'],
156
+ instance_id=inst_info['inst_id'],
157
+ )
158
+ vnic = list_vnic_attachments_response.data[0]
159
+ return oci_adaptor.get_net_client(
160
+ region, oci_utils.oci_config.get_profile()).get_vnic(
161
+ vnic_id=vnic.vnic_id).data
162
+
163
+ @classmethod
164
+ @debug_enabled(logger)
92
165
  def subscribe_image(cls, compartment_id, listing_id, resource_version,
93
166
  region):
94
- if (pd.isna(listing_id) or listing_id.strip() == "None" or
95
- listing_id.strip() == "nan"):
167
+ if (pd.isna(listing_id) or listing_id.strip() == 'None' or
168
+ listing_id.strip() == 'nan'):
96
169
  return
97
170
 
98
171
  core_client = oci_adaptor.get_core_client(
99
172
  region, oci_utils.oci_config.get_profile())
100
173
  try:
101
- agreements_response = core_client.get_app_catalog_listing_agreements(
174
+ agreements_resp = core_client.get_app_catalog_listing_agreements(
102
175
  listing_id=listing_id, resource_version=resource_version)
103
- agreements = agreements_response.data
176
+ agreements = agreements_resp.data
104
177
 
105
178
  core_client.create_app_catalog_subscription(
106
179
  create_app_catalog_subscription_details=oci_adaptor.oci.core.
@@ -113,24 +186,24 @@ class oci_query_helper:
113
186
  oracle_terms_of_use_link,
114
187
  time_retrieved=datetime.strptime(
115
188
  re.sub(
116
- "\d{3}\+\d{2}\:\d{2}",
117
- "Z",
189
+ r'\d{3}\+\d{2}\:\d{2}',
190
+ 'Z',
118
191
  str(agreements.time_retrieved),
119
192
  0,
120
193
  ),
121
- "%Y-%m-%d %H:%M:%S.%fZ",
194
+ '%Y-%m-%d %H:%M:%S.%fZ',
122
195
  ),
123
196
  signature=agreements.signature,
124
197
  eula_link=agreements.eula_link,
125
198
  ))
126
- except Exception as e:
199
+ except oci_adaptor.oci.exceptions.ServiceError as e:
127
200
  logger.critical(
128
- f"subscribe_image: {listing_id} - {resource_version} ... [Failed]"
129
- f"Error message: {str(e)}")
130
- raise RuntimeError("ERR: Image subscription error!")
201
+ f'[Failed] subscribe_image: {listing_id} - {resource_version}'
202
+ f'Error message: {str(e)}')
203
+ raise RuntimeError('ERR: Image subscription error!') from e
131
204
 
132
205
  @classmethod
133
- @utils.debug_enabled(logger=logger)
206
+ @debug_enabled(logger)
134
207
  def find_compartment(cls, region) -> str:
135
208
  """ If compartment is not configured, we use root compartment """
136
209
  # Try to use the configured one first
@@ -143,12 +216,18 @@ class oci_query_helper:
143
216
  # config file is supported (2023/06/09).
144
217
  root = oci_adaptor.get_oci_config(
145
218
  region, oci_utils.oci_config.get_profile())['tenancy']
219
+
146
220
  list_compartments_response = oci_adaptor.get_identity_client(
147
221
  region, oci_utils.oci_config.get_profile()).list_compartments(
148
222
  compartment_id=root,
149
223
  name=oci_utils.oci_config.COMPARTMENT,
224
+ compartment_id_in_subtree=True,
225
+ access_level='ACCESSIBLE',
150
226
  lifecycle_state='ACTIVE',
227
+ sort_by='TIMECREATED',
228
+ sort_order='DESC',
151
229
  limit=1)
230
+
152
231
  compartments = list_compartments_response.data
153
232
  if len(compartments) > 0:
154
233
  skypilot_compartment = compartments[0].id
@@ -159,7 +238,7 @@ class oci_query_helper:
159
238
  return skypilot_compartment
160
239
 
161
240
  @classmethod
162
- @utils.debug_enabled(logger=logger)
241
+ @debug_enabled(logger)
163
242
  def find_create_vcn_subnet(cls, region) -> Optional[str]:
164
243
  """ If sub is not configured, we find/create VCN skypilot_vcn """
165
244
  subnet = oci_utils.oci_config.get_vcn_subnet(region)
@@ -174,7 +253,7 @@ class oci_query_helper:
174
253
  list_vcns_response = net_client.list_vcns(
175
254
  compartment_id=skypilot_compartment,
176
255
  display_name=oci_utils.oci_config.VCN_NAME,
177
- lifecycle_state="AVAILABLE")
256
+ lifecycle_state='AVAILABLE')
178
257
  vcns = list_vcns_response.data
179
258
  if len(vcns) > 0:
180
259
  # Found the VCN.
@@ -184,7 +263,7 @@ class oci_query_helper:
184
263
  limit=1,
185
264
  vcn_id=skypilot_vcn,
186
265
  display_name=oci_utils.oci_config.VCN_SUBNET_NAME,
187
- lifecycle_state="AVAILABLE")
266
+ lifecycle_state='AVAILABLE')
188
267
  logger.debug(f'Got VCN subnet \n{list_subnets_response.data}')
189
268
  if len(list_subnets_response.data) < 1:
190
269
  logger.error(
@@ -201,10 +280,17 @@ class oci_query_helper:
201
280
  return cls.create_vcn_subnet(net_client, skypilot_compartment)
202
281
 
203
282
  @classmethod
204
- @utils.debug_enabled(logger=logger)
283
+ @debug_enabled(logger)
205
284
  def create_vcn_subnet(cls, net_client,
206
285
  skypilot_compartment) -> Optional[str]:
286
+
287
+ skypilot_vcn = None # VCN for the resources
288
+ subnet = None # Subnet for the VMs
289
+ ig = None # Internet gateway
290
+ sg = None # Service gateway
291
+
207
292
  try:
293
+ # pylint: disable=line-too-long
208
294
  create_vcn_response = net_client.create_vcn(
209
295
  create_vcn_details=oci_adaptor.oci.core.models.CreateVcnDetails(
210
296
  compartment_id=skypilot_compartment,
@@ -274,38 +360,38 @@ class oci_query_helper:
274
360
  update_security_list_details=oci_adaptor.oci.core.models.
275
361
  UpdateSecurityListDetails(ingress_security_rules=[
276
362
  oci_adaptor.oci.core.models.IngressSecurityRule(
277
- protocol="6",
363
+ protocol='6',
278
364
  source=oci_utils.oci_config.VCN_CIDR_INTERNET,
279
365
  is_stateless=False,
280
- source_type="CIDR_BLOCK",
366
+ source_type='CIDR_BLOCK',
281
367
  tcp_options=oci_adaptor.oci.core.models.TcpOptions(
282
368
  destination_port_range=oci_adaptor.oci.core.models.
283
369
  PortRange(max=22, min=22),
284
370
  source_port_range=oci_adaptor.oci.core.models.
285
371
  PortRange(max=65535, min=1)),
286
- description="Allow SSH port."),
372
+ description='Allow SSH port.'),
287
373
  oci_adaptor.oci.core.models.IngressSecurityRule(
288
- protocol="all",
374
+ protocol='all',
289
375
  source=oci_utils.oci_config.VCN_SUBNET_CIDR,
290
376
  is_stateless=False,
291
- source_type="CIDR_BLOCK",
292
- description="Allow all traffic from/to same subnet."),
377
+ source_type='CIDR_BLOCK',
378
+ description='Allow all traffic from/to same subnet.'),
293
379
  oci_adaptor.oci.core.models.IngressSecurityRule(
294
- protocol="1",
380
+ protocol='1',
295
381
  source=oci_utils.oci_config.VCN_CIDR_INTERNET,
296
382
  is_stateless=False,
297
- source_type="CIDR_BLOCK",
383
+ source_type='CIDR_BLOCK',
298
384
  icmp_options=oci_adaptor.oci.core.models.IcmpOptions(
299
385
  type=3, code=4),
300
- description="ICMP traffic."),
386
+ description='ICMP traffic.'),
301
387
  oci_adaptor.oci.core.models.IngressSecurityRule(
302
- protocol="1",
388
+ protocol='1',
303
389
  source=oci_utils.oci_config.VCN_CIDR,
304
390
  is_stateless=False,
305
- source_type="CIDR_BLOCK",
391
+ source_type='CIDR_BLOCK',
306
392
  icmp_options=oci_adaptor.oci.core.models.IcmpOptions(
307
393
  type=3),
308
- description="ICMP traffic (VCN)."),
394
+ description='ICMP traffic (VCN).'),
309
395
  ]))
310
396
  logger.debug(
311
397
  f'Updated security_list: \n{update_security_list_response.data}'
@@ -325,7 +411,7 @@ class oci_query_helper:
325
411
  ]))
326
412
  logger.debug(f'Route table: \n{update_route_table_response.data}')
327
413
 
328
- except oci_adaptor.service_exception() as e:
414
+ except oci_adaptor.oci.exceptions.ServiceError as e:
329
415
  logger.error(f'Create VCN Error: Create new VCN '
330
416
  f'{oci_utils.oci_config.VCN_NAME} failed: {str(e)}')
331
417
  # In case of partial success while creating vcn
@@ -335,7 +421,7 @@ class oci_query_helper:
335
421
  return subnet
336
422
 
337
423
  @classmethod
338
- @utils.debug_enabled(logger=logger)
424
+ @debug_enabled(logger)
339
425
  def delete_vcn(cls, net_client, skypilot_vcn, skypilot_subnet,
340
426
  internet_gateway, service_gateway):
341
427
  if skypilot_vcn is None:
@@ -369,7 +455,7 @@ class oci_query_helper:
369
455
  f'Deleted vcn {skypilot_vcn}-{delete_vcn_response.data}'
370
456
  )
371
457
  break
372
- except oci_adaptor.service_exception() as e:
458
+ except oci_adaptor.oci.exceptions.ServiceError as e:
373
459
  logger.info(f'Waiting del SG/IG/Subnet finish: {str(e)}')
374
460
  retry_count = retry_count + 1
375
461
  if retry_count == oci_utils.oci_config.MAX_RETRY_COUNT:
@@ -378,6 +464,9 @@ class oci_query_helper:
378
464
  time.sleep(
379
465
  oci_utils.oci_config.RETRY_INTERVAL_BASE_SECONDS)
380
466
 
381
- except oci_adaptor.service_exception() as e:
467
+ except oci_adaptor.oci.exceptions.ServiceError as e:
382
468
  logger.error(
383
469
  f'Delete VCN {oci_utils.oci_config.VCN_NAME} Error: {str(e)}')
470
+
471
+
472
+ query_helper = QueryHelper()
sky/serve/__init__.py CHANGED
@@ -11,6 +11,7 @@ from sky.serve.core import tail_logs
11
11
  from sky.serve.core import terminate_replica
12
12
  from sky.serve.core import up
13
13
  from sky.serve.core import update
14
+ from sky.serve.load_balancing_policies import LB_POLICIES
14
15
  from sky.serve.serve_state import ReplicaStatus
15
16
  from sky.serve.serve_state import ServiceStatus
16
17
  from sky.serve.serve_utils import DEFAULT_UPDATE_MODE
@@ -35,6 +36,7 @@ __all__ = [
35
36
  'get_endpoint',
36
37
  'INITIAL_VERSION',
37
38
  'LB_CONTROLLER_SYNC_INTERVAL_SECONDS',
39
+ 'LB_POLICIES',
38
40
  'ReplicaStatus',
39
41
  'ServiceComponent',
40
42
  'ServiceStatus',
sky/serve/load_balancer.py CHANGED
@@ -2,7 +2,7 @@
2
2
  import asyncio
3
3
  import logging
4
4
  import threading
5
- from typing import Dict, Union
5
+ from typing import Dict, Optional, Union
6
6
 
7
7
  import aiohttp
8
8
  import fastapi
@@ -27,18 +27,24 @@ class SkyServeLoadBalancer:
27
27
  policy.
28
28
  """
29
29
 
30
- def __init__(self, controller_url: str, load_balancer_port: int) -> None:
30
+ def __init__(self,
31
+ controller_url: str,
32
+ load_balancer_port: int,
33
+ load_balancing_policy_name: Optional[str] = None) -> None:
31
34
  """Initialize the load balancer.
32
35
 
33
36
  Args:
34
37
  controller_url: The URL of the controller.
35
38
  load_balancer_port: The port where the load balancer listens to.
39
+ load_balancing_policy_name: The name of the load balancing policy
40
+ to use. Defaults to None.
36
41
  """
37
42
  self._app = fastapi.FastAPI()
38
43
  self._controller_url: str = controller_url
39
44
  self._load_balancer_port: int = load_balancer_port
40
- self._load_balancing_policy: lb_policies.LoadBalancingPolicy = (
41
- lb_policies.RoundRobinPolicy())
45
+ # Use the registry to create the load balancing policy
46
+ self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make(
47
+ load_balancing_policy_name)
42
48
  self._request_aggregator: serve_utils.RequestsAggregator = (
43
49
  serve_utils.RequestTimestamp())
44
50
  # TODO(tian): httpx.Client has a resource limit of 100 max connections
@@ -223,9 +229,21 @@ class SkyServeLoadBalancer:
223
229
  uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port)
224
230
 
225
231
 
226
- def run_load_balancer(controller_addr: str, load_balancer_port: int):
227
- load_balancer = SkyServeLoadBalancer(controller_url=controller_addr,
228
- load_balancer_port=load_balancer_port)
232
+ def run_load_balancer(controller_addr: str,
233
+ load_balancer_port: int,
234
+ load_balancing_policy_name: Optional[str] = None) -> None:
235
+ """ Run the load balancer.
236
+
237
+ Args:
238
+ controller_addr: The address of the controller.
239
+ load_balancer_port: The port where the load balancer listens to.
240
+ policy_name: The name of the load balancing policy to use. Defaults to
241
+ None.
242
+ """
243
+ load_balancer = SkyServeLoadBalancer(
244
+ controller_url=controller_addr,
245
+ load_balancer_port=load_balancer_port,
246
+ load_balancing_policy_name=load_balancing_policy_name)
229
247
  load_balancer.run()
230
248
 
231
249
 
@@ -241,5 +259,13 @@ if __name__ == '__main__':
241
259
  required=True,
242
260
  default=8890,
243
261
  help='The port where the load balancer listens to.')
262
+ available_policies = list(lb_policies.LB_POLICIES.keys())
263
+ parser.add_argument(
264
+ '--load-balancing-policy',
265
+ choices=available_policies,
266
+ default='round_robin',
267
+ help=f'The load balancing policy to use. Available policies: '
268
+ f'{", ".join(available_policies)}.')
244
269
  args = parser.parse_args()
245
- run_load_balancer(args.controller_addr, args.load_balancer_port)
270
+ run_load_balancer(args.controller_addr, args.load_balancer_port,
271
+ args.load_balancing_policy)
sky/serve/load_balancing_policies.py CHANGED
@@ -10,6 +10,10 @@ if typing.TYPE_CHECKING:
10
10
 
11
11
  logger = sky_logging.init_logger(__name__)
12
12
 
13
+ # Define a registry for load balancing policies
14
+ LB_POLICIES = {}
15
+ DEFAULT_LB_POLICY = None
16
+
13
17
 
14
18
  def _request_repr(request: 'fastapi.Request') -> str:
15
19
  return ('<Request '
@@ -25,6 +29,24 @@ class LoadBalancingPolicy:
25
29
  def __init__(self) -> None:
26
30
  self.ready_replicas: List[str] = []
27
31
 
32
+ def __init_subclass__(cls, name: str, default: bool = False):
33
+ LB_POLICIES[name] = cls
34
+ if default:
35
+ global DEFAULT_LB_POLICY
36
+ assert DEFAULT_LB_POLICY is None, (
37
+ 'Only one policy can be default.')
38
+ DEFAULT_LB_POLICY = name
39
+
40
+ @classmethod
41
+ def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy':
42
+ """Create a load balancing policy from a name."""
43
+ if policy_name is None:
44
+ policy_name = DEFAULT_LB_POLICY
45
+
46
+ if policy_name not in LB_POLICIES:
47
+ raise ValueError(f'Unknown load balancing policy: {policy_name}')
48
+ return LB_POLICIES[policy_name]()
49
+
28
50
  def set_ready_replicas(self, ready_replicas: List[str]) -> None:
29
51
  raise NotImplementedError
30
52
 
@@ -44,7 +66,7 @@ class LoadBalancingPolicy:
44
66
  raise NotImplementedError
45
67
 
46
68
 
47
- class RoundRobinPolicy(LoadBalancingPolicy):
69
+ class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True):
48
70
  """Round-robin load balancing policy."""
49
71
 
50
72
  def __init__(self) -> None:
sky/serve/service.py CHANGED
@@ -219,6 +219,9 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
219
219
  load_balancer_port = common_utils.find_free_port(
220
220
  constants.LOAD_BALANCER_PORT_START)
221
221
 
222
+ # Extract the load balancing policy from the service spec
223
+ policy_name = service_spec.load_balancing_policy
224
+
222
225
  # Start the load balancer.
223
226
  # TODO(tian): Probably we could enable multiple ports specified in
224
227
  # service spec and we could start multiple load balancers.
@@ -227,7 +230,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
227
230
  target=ux_utils.RedirectOutputForProcess(
228
231
  load_balancer.run_load_balancer,
229
232
  load_balancer_log_file).run,
230
- args=(controller_addr, load_balancer_port))
233
+ args=(controller_addr, load_balancer_port, policy_name))
231
234
  load_balancer_process.start()
232
235
  serve_state.set_service_load_balancer_port(service_name,
233
236
  load_balancer_port)
sky/serve/service_spec.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional
6
6
 
7
7
  import yaml
8
8
 
9
+ from sky import serve
9
10
  from sky.serve import constants
10
11
  from sky.utils import common_utils
11
12
  from sky.utils import schemas
@@ -29,6 +30,7 @@ class SkyServiceSpec:
29
30
  base_ondemand_fallback_replicas: Optional[int] = None,
30
31
  upscale_delay_seconds: Optional[int] = None,
31
32
  downscale_delay_seconds: Optional[int] = None,
33
+ load_balancing_policy: Optional[str] = None,
32
34
  ) -> None:
33
35
  if max_replicas is not None and max_replicas < min_replicas:
34
36
  with ux_utils.print_exception_no_traceback():
@@ -55,6 +57,13 @@ class SkyServiceSpec:
55
57
  raise ValueError('readiness_path must start with a slash (/). '
56
58
  f'Got: {readiness_path}')
57
59
 
60
+ # Add the check for unknown load balancing policies
61
+ if (load_balancing_policy is not None and
62
+ load_balancing_policy not in serve.LB_POLICIES):
63
+ with ux_utils.print_exception_no_traceback():
64
+ raise ValueError(
65
+ f'Unknown load balancing policy: {load_balancing_policy}. '
66
+ f'Available policies: {list(serve.LB_POLICIES.keys())}')
58
67
  self._readiness_path: str = readiness_path
59
68
  self._initial_delay_seconds: int = initial_delay_seconds
60
69
  self._readiness_timeout_seconds: int = readiness_timeout_seconds
@@ -69,6 +78,7 @@ class SkyServiceSpec:
69
78
  int] = base_ondemand_fallback_replicas
70
79
  self._upscale_delay_seconds: Optional[int] = upscale_delay_seconds
71
80
  self._downscale_delay_seconds: Optional[int] = downscale_delay_seconds
81
+ self._load_balancing_policy: Optional[str] = load_balancing_policy
72
82
 
73
83
  self._use_ondemand_fallback: bool = (
74
84
  self.dynamic_ondemand_fallback is not None and
@@ -150,6 +160,8 @@ class SkyServiceSpec:
150
160
  service_config['dynamic_ondemand_fallback'] = policy_section.get(
151
161
  'dynamic_ondemand_fallback', None)
152
162
 
163
+ service_config['load_balancing_policy'] = config.get(
164
+ 'load_balancing_policy', None)
153
165
  return SkyServiceSpec(**service_config)
154
166
 
155
167
  @staticmethod
@@ -205,6 +217,8 @@ class SkyServiceSpec:
205
217
  self.upscale_delay_seconds)
206
218
  add_if_not_none('replica_policy', 'downscale_delay_seconds',
207
219
  self.downscale_delay_seconds)
220
+ add_if_not_none('load_balancing_policy', None,
221
+ self._load_balancing_policy)
208
222
  return config
209
223
 
210
224
  def probe_str(self):
@@ -256,6 +270,7 @@ class SkyServiceSpec:
256
270
  Readiness probe timeout seconds: {self.readiness_timeout_seconds}
257
271
  Replica autoscaling policy: {self.autoscaling_policy_str()}
258
272
  Spot Policy: {self.spot_policy_str()}
273
+ Load Balancing Policy: {self.load_balancing_policy}
259
274
  """)
260
275
 
261
276
  @property
@@ -310,3 +325,7 @@ class SkyServiceSpec:
310
325
  @property
311
326
  def use_ondemand_fallback(self) -> bool:
312
327
  return self._use_ondemand_fallback
328
+
329
+ @property
330
+ def load_balancing_policy(self) -> Optional[str]:
331
+ return self._load_balancing_policy
sky/setup_files/MANIFEST.in CHANGED
@@ -6,7 +6,6 @@ include sky/setup_files/*
6
6
  include sky/skylet/*.sh
7
7
  include sky/skylet/LICENSE
8
8
  include sky/skylet/providers/ibm/*
9
- include sky/skylet/providers/oci/*
10
9
  include sky/skylet/providers/scp/*
11
10
  include sky/skylet/providers/*.py
12
11
  include sky/skylet/ray_patches/*.patch