skypilot-nightly 1.0.0.dev20250824__py3-none-any.whl → 1.0.0.dev20250825__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly has been flagged as potentially problematic.

Files changed (33)
  1. sky/__init__.py +2 -2
  2. sky/catalog/data_fetchers/fetch_lambda_cloud.py +1 -0
  3. sky/dashboard/out/404.html +1 -1
  4. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  5. sky/dashboard/out/clusters/[cluster].html +1 -1
  6. sky/dashboard/out/clusters.html +1 -1
  7. sky/dashboard/out/config.html +1 -1
  8. sky/dashboard/out/index.html +1 -1
  9. sky/dashboard/out/infra/[context].html +1 -1
  10. sky/dashboard/out/infra.html +1 -1
  11. sky/dashboard/out/jobs/[job].html +1 -1
  12. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  13. sky/dashboard/out/jobs.html +1 -1
  14. sky/dashboard/out/users.html +1 -1
  15. sky/dashboard/out/volumes.html +1 -1
  16. sky/dashboard/out/workspace/new.html +1 -1
  17. sky/dashboard/out/workspaces/[name].html +1 -1
  18. sky/dashboard/out/workspaces.html +1 -1
  19. sky/serve/autoscalers.py +347 -5
  20. sky/serve/controller.py +36 -6
  21. sky/serve/load_balancer.py +49 -16
  22. sky/serve/load_balancing_policies.py +115 -1
  23. sky/serve/service.py +2 -1
  24. sky/serve/service_spec.py +26 -4
  25. sky/utils/schemas.py +18 -2
  26. {skypilot_nightly-1.0.0.dev20250824.dist-info → skypilot_nightly-1.0.0.dev20250825.dist-info}/METADATA +1 -1
  27. {skypilot_nightly-1.0.0.dev20250824.dist-info → skypilot_nightly-1.0.0.dev20250825.dist-info}/RECORD +33 -33
  28. /sky/dashboard/out/_next/static/{P-Au_yHqNENhnPF3shEpK → n7XGGtvnHqbVUS8eayoGG}/_buildManifest.js +0 -0
  29. /sky/dashboard/out/_next/static/{P-Au_yHqNENhnPF3shEpK → n7XGGtvnHqbVUS8eayoGG}/_ssgManifest.js +0 -0
  30. {skypilot_nightly-1.0.0.dev20250824.dist-info → skypilot_nightly-1.0.0.dev20250825.dist-info}/WHEEL +0 -0
  31. {skypilot_nightly-1.0.0.dev20250824.dist-info → skypilot_nightly-1.0.0.dev20250825.dist-info}/entry_points.txt +0 -0
  32. {skypilot_nightly-1.0.0.dev20250824.dist-info → skypilot_nightly-1.0.0.dev20250825.dist-info}/licenses/LICENSE +0 -0
  33. {skypilot_nightly-1.0.0.dev20250824.dist-info → skypilot_nightly-1.0.0.dev20250825.dist-info}/top_level.txt +0 -0
sky/serve/autoscalers.py CHANGED
@@ -6,7 +6,7 @@ import enum
 import math
 import time
 import typing
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 from sky import sky_logging
 from sky.serve import constants
@@ -213,6 +213,10 @@ class Autoscaler:
         # TODO(MaoZiming): use NAME to get the class.
         if spec.use_ondemand_fallback:
             return FallbackRequestRateAutoscaler(service_name, spec)
+        elif isinstance(spec.target_qps_per_replica, dict):
+            # Use instance-aware autoscaler
+            # when target_qps_per_replica is a dict
+            return InstanceAwareRequestRateAutoscaler(service_name, spec)
         else:
             return RequestRateAutoscaler(service_name, spec)

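For orientation, a minimal sketch of what this dispatch enables (the service name and QPS values are invented for illustration):

    # target_qps_per_replica may now be a dict mapping GPU type -> QPS target.
    spec.target_qps_per_replica = {'A100': 10.0, 'L4': 4.0}
    autoscaler = Autoscaler.from_spec('my-service', spec)
    # Returns InstanceAwareRequestRateAutoscaler, unless
    # spec.use_ondemand_fallback is set, which takes precedence per the
    # branch order above.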
@@ -464,20 +468,28 @@ class RequestRateAutoscaler(_AutoscalerWithHysteresis):
             request_timestamps: All request timestamps within the window.
         """
         super().__init__(service_name, spec)
-        self.target_qps_per_replica: Optional[
-            float] = spec.target_qps_per_replica
+        self.target_qps_per_replica: Optional[Union[float, Dict[
+            str, float]]] = spec.target_qps_per_replica
         self.qps_window_size: int = constants.AUTOSCALER_QPS_WINDOW_SIZE_SECONDS
         self.request_timestamps: List[float] = []

     def _calculate_target_num_replicas(self) -> int:
         if self.target_qps_per_replica is None:
             return self.min_replicas
+
+        # RequestRateAutoscaler should only handle float values
+        if isinstance(self.target_qps_per_replica, dict):
+            raise ValueError('RequestRateAutoscaler does not support dict '
+                             'target_qps_per_replica. Should use '
+                             'InstanceAwareRequestRateAutoscaler instead.')
+
         num_requests_per_second = len(
             self.request_timestamps) / self.qps_window_size
-        target_num_replicas = math.ceil(num_requests_per_second /
-                                        self.target_qps_per_replica)
+        target_num_replicas = \
+            math.ceil(num_requests_per_second / self.target_qps_per_replica)
         logger.info(f'Requests per second: {num_requests_per_second}. '
                     f'Target number of replicas: {target_num_replicas}.')
+
         return self._clip_target_num_replicas(target_num_replicas)

     def update_version(self, version: int, spec: 'service_spec.SkyServiceSpec',
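A worked example of the float path (numbers invented; assumes the QPS window constant is 60 seconds):

    import math
    window = 60                        # constants.AUTOSCALER_QPS_WINDOW_SIZE_SECONDS (assumed)
    num_requests_per_second = 120 / window             # 120 timestamps in window -> 2.0 QPS
    target = math.ceil(num_requests_per_second / 0.5)  # target_qps_per_replica = 0.5
    assert target == 4                 # then clipped to [min_replicas, max_replicas]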
@@ -510,6 +522,7 @@ class RequestRateAutoscaler(_AutoscalerWithHysteresis):
     ) -> List[AutoscalerDecision]:
         """Generate Autoscaling decisions based on request rate."""

+        # Use standard hysteresis-based logic (non-instance-aware)
         self._set_target_num_replicas_with_hysteresis()

         latest_nonterminal_replicas: List['replica_managers.ReplicaInfo'] = []
@@ -538,6 +551,7 @@ class RequestRateAutoscaler(_AutoscalerWithHysteresis):
         if len(latest_nonterminal_replicas) > target_num_replicas:
             num_replicas_to_scale_down = (len(latest_nonterminal_replicas) -
                                           target_num_replicas)
+            # Use standard downscaling logic
             replicas_to_scale_down = (
                 _select_nonterminal_replicas_to_scale_down(
                     num_replicas_to_scale_down, latest_nonterminal_replicas))
@@ -562,6 +576,334 @@ class RequestRateAutoscaler(_AutoscalerWithHysteresis):
         logger.info(f'Remaining dynamic states: {dynamic_states}')


+class InstanceAwareRequestRateAutoscaler(RequestRateAutoscaler):
+    """Instance-aware RequestRateAutoscaler:
+    Autoscale based on each replica's GPU-specific QPS.
+
+    This autoscaler considers different QPS targets for different GPU types
+    when target_qps_per_replica is provided as a dictionary mapping GPU types
+    to their respective QPS targets.
+    """
+
+    def __init__(self, service_name: str,
+                 spec: 'service_spec.SkyServiceSpec') -> None:
+        super().__init__(service_name, spec)
+        # Ensure target_qps_per_replica is a dict for instance-aware logic
+        assert isinstance(spec.target_qps_per_replica, dict), \
+            'InstanceAware Autoscaler requires dict type target_qps_per_replica'
+        # Re-assign with correct type using setattr to avoid typing issues
+        self.target_qps_per_replica = spec.target_qps_per_replica
+
+    def _generate_scaling_decisions(
+        self,
+        replica_infos: List['replica_managers.ReplicaInfo'],
+    ) -> List[AutoscalerDecision]:
+        """Generate autoscaling decisions with instance-aware logic."""
+        # Always use instance-aware logic
+        # since target_qps_per_replica is guaranteed to be dict
+        self._set_target_num_replicas_with_instance_aware_logic(replica_infos)
+
+        latest_nonterminal_replicas: List['replica_managers.ReplicaInfo'] = []
+
+        for info in replica_infos:
+            if not info.is_terminal and info.version == self.latest_version:
+                latest_nonterminal_replicas.append(info)
+
+        target_num_replicas = self.get_final_target_num_replicas()
+        current_num_replicas = len(latest_nonterminal_replicas)
+
+        scaling_decisions: List[AutoscalerDecision] = []
+
+        # Decide if to scale up or down.
+        if target_num_replicas > current_num_replicas:
+            for _ in range(target_num_replicas - current_num_replicas):
+                # No resources_override to use when scaling up
+                scaling_decisions.append(
+                    AutoscalerDecision(AutoscalerDecisionOperator.SCALE_UP,
+                                       target=None))
+        elif target_num_replicas < current_num_replicas:
+            num_replicas_to_scale_down = \
+                current_num_replicas - target_num_replicas
+
+            # Use instance-aware scale down logic
+            replicas_to_scale_down = self._select_replicas_to_scale_down_by_qps(
+                num_replicas_to_scale_down, latest_nonterminal_replicas)
+            for replica_id in replicas_to_scale_down:
+                scaling_decisions.append(
+                    AutoscalerDecision(AutoscalerDecisionOperator.SCALE_DOWN,
+                                       target=replica_id))
+
+        # Outdated replicas are handled by base class generate_scaling_decisions
+        # No need to handle them here
+
+        upscale_decisions = [
+            d for d in scaling_decisions
+            if d.operator == AutoscalerDecisionOperator.SCALE_UP
+        ]
+        downscale_decisions = [
+            d for d in scaling_decisions
+            if d.operator == AutoscalerDecisionOperator.SCALE_DOWN
+        ]
+        logger.info(f'Scaling decisions: '
+                    f'{len(upscale_decisions)} scale up, '
+                    f'{len(downscale_decisions)} scale down '
+                    f'(latest nonterminal: {current_num_replicas}, '
+                    f'target: {target_num_replicas})')
+
+        return scaling_decisions
+
+    def _set_target_num_replicas_with_instance_aware_logic(
+            self, replica_infos: List['replica_managers.ReplicaInfo']) -> None:
+        """Set target_num_replicas using instance-aware logic."""
+        assert isinstance(self.target_qps_per_replica,
+                          dict), 'Expected dict for instance-aware logic'
+        target_qps_dict = self.target_qps_per_replica
+
+        num_requests_per_second = len(
+            self.request_timestamps) / self.qps_window_size
+
+        total_qps = self._calculate_total_qps_from_replicas(replica_infos)
+        if total_qps > 0:
+            if num_requests_per_second >= total_qps:
+                # for upscaling, max_target_qps is the standard qps
+                max_target_qps = max(target_qps_dict.values())
+                over_request_num = num_requests_per_second - total_qps
+                current_num_replicas = len(replica_infos)
+                raw_target_num = current_num_replicas + math.ceil(
+                    over_request_num / max_target_qps)
+                target_num_replicas = self._clip_target_num_replicas(
+                    raw_target_num)
+                logger.info(
+                    f'Instance-aware autoscaling: total QPS {total_qps}, '
+                    f'num_requests_per_second: {num_requests_per_second}, '
+                    f'upscaling, using maximum QPS {max_target_qps} '
+                    f'from {target_qps_dict}, '
+                    f'target replicas: {target_num_replicas}')
+            else:
+                # for downscaling, use qps for every ready_target_qps_list
+                # to calculate target_num_replicas
+                ready_target_qps_list = \
+                    self._extract_target_qps_list_from_ready_replicas(
+                        replica_infos)
+                ready_target_qps_list = sorted(ready_target_qps_list,
+                                               reverse=True)
+                if not ready_target_qps_list:
+                    # Fallback to maximum QPS from config if no ready replicas
+                    ready_target_qps_list = [max(target_qps_dict.values())]
+
+                raw_target_num = 0
+                qps_sum = 0.0
+                for qps in ready_target_qps_list:
+                    raw_target_num += 1
+                    qps_sum += qps
+                    if qps_sum > num_requests_per_second:
+                        break
+
+                target_num_replicas = self._clip_target_num_replicas(
+                    raw_target_num)
+                logger.info(
+                    f'Instance-aware autoscaling: total QPS {total_qps}, '
+                    f'num_requests_per_second: {num_requests_per_second}, '
+                    f'downscaling, using ready QPS list '
+                    f'{ready_target_qps_list}, '
+                    f'target replicas: {target_num_replicas}')
+        else:
+            # no replica is ready; use the normal min_replicas
+            target_num_replicas = self._clip_target_num_replicas(
+                self.min_replicas)
+            logger.info(f'Instance-aware autoscaling: no replica QPS available,'
+                        f' target replicas: {target_num_replicas}')
+
+        # Apply hysteresis logic
+        old_target_num_replicas = self.target_num_replicas
+
+        # Faster scale up when there is no replica.
+        if self.target_num_replicas == 0:
+            self.target_num_replicas = target_num_replicas
+        elif target_num_replicas > self.target_num_replicas:
+            self.upscale_counter += 1
+            self.downscale_counter = 0
+            if self.upscale_counter >= self.scale_up_threshold:
+                self.upscale_counter = 0
+                self.target_num_replicas = target_num_replicas
+        elif target_num_replicas < self.target_num_replicas:
+            self.downscale_counter += 1
+            self.upscale_counter = 0
+            if self.downscale_counter >= self.scale_down_threshold:
+                self.downscale_counter = 0
+                self.target_num_replicas = target_num_replicas
+        else:
+            self.upscale_counter = self.downscale_counter = 0
+
+        logger.info(
+            f'Instance-aware: Old target number of replicas: '
+            f'{old_target_num_replicas}. '
+            f'Current target number of replicas: {target_num_replicas}. '
+            f'Final target number of replicas: {self.target_num_replicas}. '
+            f'Num overprovision: {self.num_overprovision}. '
+            f'Upscale counter: {self.upscale_counter}/'
+            f'{self.scale_up_threshold}. '
+            f'Downscale counter: {self.downscale_counter}/'
+            f'{self.scale_down_threshold}. ')
+
+    def _calculate_total_qps_from_replicas(
+            self, replica_infos: List['replica_managers.ReplicaInfo']) -> float:
+        """Calculate total QPS based on current replica GPU types."""
+        total_qps = 0.0
+        logger.info(f'Calculating total QPS from {len(replica_infos)} replicas')
+
+        for replica_info in replica_infos:
+            # Skip non-valid replicas
+            valid_statuses = [
+                serve_state.ReplicaStatus.READY,
+                serve_state.ReplicaStatus.STARTING,
+                serve_state.ReplicaStatus.PROVISIONING
+            ]
+            if replica_info.status not in valid_statuses:
+                logger.info(f'Skipping replica {replica_info.replica_id} '
+                            f'with status: {replica_info.status}')
+                continue
+
+            gpu_type = self._get_gpu_type_from_replica_info(replica_info)
+            logger.info(f'Processing replica {replica_info.replica_id} '
+                        f'with GPU type: {gpu_type}')
+
+            # Use flexible matching logic
+            qps_for_this_gpu = self._get_target_qps_for_gpu_type(gpu_type)
+            total_qps += qps_for_this_gpu
+            logger.info(f'GPU type {gpu_type} -> {qps_for_this_gpu} QPS')
+
+        logger.info(f'Calculated total QPS: {total_qps}')
+        return total_qps
+
+    def _get_target_qps_for_gpu_type(self, gpu_type: str) -> float:
+        """Get target QPS for a specific GPU type with flexible matching."""
+        assert isinstance(self.target_qps_per_replica,
+                          dict), 'Expected dict for instance-aware logic'
+        target_qps_dict = self.target_qps_per_replica
+
+        # Direct match first
+        if gpu_type in target_qps_dict:
+            return target_qps_dict[gpu_type]
+
+        # Try matching by base name (e.g., 'A100' matches 'A100:1')
+        for config_key in target_qps_dict.keys():
+            # Remove count suffix (e.g., 'A100:1' -> 'A100')
+            base_name = config_key.split(':')[0]
+            if gpu_type == base_name:
+                return target_qps_dict[config_key]
+
+        # Fallback to minimum QPS
+        logger.warning(f'No matching QPS found for GPU type: {gpu_type}. '
+                       f'Available types: {list(target_qps_dict.keys())}. '
+                       f'Using minimum QPS as fallback.')
+        return min(target_qps_dict.values())
+
+    def _get_gpu_type_from_replica_info(
+            self, replica_info: 'replica_managers.ReplicaInfo') -> str:
+        """Extract GPU type from ReplicaInfo object."""
+        gpu_type = 'unknown'
+        handle = replica_info.handle()
+        if handle is not None:
+            accelerators = handle.launched_resources.accelerators
+            if accelerators and len(accelerators) > 0:
+                # Get the first accelerator type
+                gpu_type = list(accelerators.keys())[0]
+        return gpu_type
+
+    def _extract_target_qps_list_from_ready_replicas(
+        self,
+        replica_infos: List['replica_managers.ReplicaInfo']) -> List[float]:
+        """Extract target QPS list from current READY replicas."""
+        ready_replica_qps = []
+
+        for replica_info in replica_infos:
+            # Check if replica is READY
+            if replica_info.status != serve_state.ReplicaStatus.READY:
+                logger.info(
+                    f'Replica {replica_info.replica_id} '
+                    f'not ready (status: {replica_info.status}), skipping')
+                continue
+
+            gpu_type = self._get_gpu_type_from_replica_info(replica_info)
+
+            # Use flexible matching logic
+            qps_for_this_gpu = self._get_target_qps_for_gpu_type(gpu_type)
+            ready_replica_qps.append(qps_for_this_gpu)
+            logger.info(f'Ready replica {replica_info.replica_id} '
+                        f'with GPU {gpu_type}: {qps_for_this_gpu} QPS')
+
+        if ready_replica_qps:
+            logger.info(
+                f'Target QPS list from ready replicas: {ready_replica_qps}')
+            return ready_replica_qps
+
+        return []
+
+    def _select_replicas_to_scale_down_by_qps(
+            self, num_replicas_to_scale_down: int,
+            replica_infos: List['replica_managers.ReplicaInfo']) -> List[int]:
+        """Select replicas to scale down (lowest QPS first)."""
+        # Create a list of (replica_info, target_qps) tuples
+        replica_qps_pairs: List[Tuple['replica_managers.ReplicaInfo',
+                                      float]] = []
+
+        for info in replica_infos:
+            # Include old-version replicas as well so they also get a target_qps
+            # assigned. Skip terminal replicas only.
+            if info.is_terminal:
+                continue
+
+            # Get GPU type directly from replica info
+            gpu_type = self._get_gpu_type_from_replica_info(info)
+
+            # Use flexible matching logic
+            target_qps = self._get_target_qps_for_gpu_type(gpu_type)
+
+            replica_qps_pairs.append((info, float(target_qps)))
+            logger.info(f'Replica {info.replica_id} '
+                        f'with GPU {gpu_type}: {target_qps} QPS')
+
+        # Create a mapping from replica_id to target_qps for sorting
+        replica_qps_map = {
+            info.replica_id: target_qps
+            for info, target_qps in replica_qps_pairs
+        }
+
+        # Sort replicas by: 1. status order, 2. target_qps (asc),
+        # 3. version (asc), 4. replica_id (desc)
+        sorted_replicas = sorted(
+            replica_infos,
+            key=lambda info: (
+                info.status.scale_down_decision_order(),
+                replica_qps_map.get(info.replica_id, float('inf')),
+                info.version,
+                -info.replica_id,
+            ))
+
+        selected_replica_ids = []
+        for info in sorted_replicas:
+            if info.is_terminal:
+                continue
+            selected_replica_ids.append(info.replica_id)
+            if len(selected_replica_ids) >= num_replicas_to_scale_down:
+                break
+
+        logger.info(
+            f'Selected {len(selected_replica_ids)} replicas to scale down: '
+            f'{selected_replica_ids}')
+        return selected_replica_ids
+
+    def update_version(self, version: int, spec: 'service_spec.SkyServiceSpec',
+                       update_mode: serve_utils.UpdateMode) -> None:
+        super(RequestRateAutoscaler,
+              self).update_version(version, spec, update_mode)
+        # Ensure it's a dict and re-assign using setattr to avoid typing
+        assert isinstance(spec.target_qps_per_replica, dict), \
+            'InstanceAware Autoscaler requires dict type target_qps_per_replica'
+        self.target_qps_per_replica = spec.target_qps_per_replica
+
+
 class FallbackRequestRateAutoscaler(RequestRateAutoscaler):
     """FallbackRequestRateAutoscaler

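The downscale sizing above is a greedy capacity walk: sort the ready replicas' per-GPU QPS targets in descending order and count replicas until their summed capacity exceeds the observed request rate. A standalone sketch (the function name and numbers are illustrative, not part of the diff):

    def replicas_needed(ready_target_qps, observed_qps):
        qps_sum, count = 0.0, 0
        for qps in sorted(ready_target_qps, reverse=True):
            count += 1
            qps_sum += qps
            if qps_sum > observed_qps:
                break
        return count

    # Two replicas (10 + 4 = 14 QPS of capacity) already cover 12.5 QPS:
    assert replicas_needed([10.0, 4.0, 4.0], observed_qps=12.5) == 2

The result is then clipped and run through the same upscale/downscale counters as the base class, so hysteresis still applies.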
sky/serve/controller.py CHANGED
@@ -78,7 +78,11 @@ class SkyServeController:
                 assert record is not None, ('No service record found for '
                                             f'{self._service_name}')
                 active_versions = record['active_versions']
-                logger.info(f'All replica info: {replica_infos}')
+                logger.info(f'All replica info for autoscaler: {replica_infos}')
+
+                # Autoscaler now extracts GPU type info directly from
+                # replica_infos in generate_scaling_decisions method
+                # for better decoupling.
                 scaling_options = self._autoscaler.generate_scaling_decisions(
                     replica_infos, active_versions)
                 for scaling_option in scaling_options:
@@ -118,11 +122,37 @@ class SkyServeController:
             timestamps: List[int] = request_aggregator.get('timestamps', [])
             logger.info(f'Received {len(timestamps)} inflight requests.')
             self._autoscaler.collect_request_information(request_aggregator)
-            return responses.JSONResponse(content={
-                'ready_replica_urls':
-                    self._replica_manager.get_active_replica_urls()
-            },
-                                          status_code=200)
+
+            # Get replica information for instance-aware load balancing
+            replica_infos = serve_state.get_replica_infos(self._service_name)
+            ready_replica_urls = self._replica_manager.get_active_replica_urls()
+
+            # Use URL-to-info mapping to avoid duplication
+            replica_info = {}
+            for info in replica_infos:
+                if info.url in ready_replica_urls:
+                    # Get GPU type from handle.launched_resources.accelerators
+                    gpu_type = 'unknown'
+                    handle = info.handle()
+                    if handle is not None:
+                        accelerators = handle.launched_resources.accelerators
+                        if accelerators and len(accelerators) > 0:
+                            # Get the first accelerator type
+                            gpu_type = list(accelerators.keys())[0]
+
+                    replica_info[info.url] = {'gpu_type': gpu_type}
+
+            # Check that all ready replica URLs are included in replica_info
+            missing_urls = set(ready_replica_urls) - set(replica_info.keys())
+            if missing_urls:
+                logger.warning(f'Ready replica URLs missing from replica_info: '
+                               f'{missing_urls}')
+                # fallback: add missing URLs with unknown GPU type
+                for url in missing_urls:
+                    replica_info[url] = {'gpu_type': 'unknown'}
+
+            return responses.JSONResponse(
+                content={'replica_info': replica_info}, status_code=200)

         @self._app.post('/controller/update_service')
         async def update_service(request: fastapi.Request) -> fastapi.Response:
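The sync endpoint's response contract therefore changes from a bare ready_replica_urls list to a replica_info mapping. An illustrative payload (URLs and GPU types made up):

    {
        'replica_info': {
            'http://10.0.0.1:8081': {'gpu_type': 'A100'},
            'http://10.0.0.2:8081': {'gpu_type': 'L4'},
        }
    }

The load balancer recovers the old URL list as list(replica_info.keys()), as shown in the load_balancer.py hunks below.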
sky/serve/load_balancer.py CHANGED
@@ -30,11 +30,13 @@ class SkyServeLoadBalancer:
     """

     def __init__(
-            self,
-            controller_url: str,
-            load_balancer_port: int,
-            load_balancing_policy_name: Optional[str] = None,
-            tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
+        self,
+        controller_url: str,
+        load_balancer_port: int,
+        load_balancing_policy_name: Optional[str] = None,
+        tls_credential: Optional[serve_utils.TLSCredential] = None,
+        target_qps_per_replica: Optional[Union[float, Dict[str, float]]] = None
+    ) -> None:
         """Initialize the load balancer.

         Args:
@@ -44,6 +46,9 @@ class SkyServeLoadBalancer:
                 to use. Defaults to None.
             tls_credentials: The TLS credentials for HTTPS endpoint. Defaults
                 to None.
+            target_qps_per_replica: Target QPS per replica for instance-aware
+                load balancing. Can be a float or dict mapping GPU types to QPS.
+                Defaults to None.
         """
         self._app = fastapi.FastAPI()
         self._controller_url: str = controller_url
@@ -51,6 +56,15 @@ class SkyServeLoadBalancer:
         # Use the registry to create the load balancing policy
         self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make(
             load_balancing_policy_name)
+
+        # Set accelerator QPS for instance-aware policies
+        if (target_qps_per_replica and
+                isinstance(target_qps_per_replica, dict) and
+                isinstance(self._load_balancing_policy,
+                           lb_policies.InstanceAwareLeastLoadPolicy)):
+            self._load_balancing_policy.set_target_qps_per_accelerator(
+                target_qps_per_replica)
+
         logger.info('Starting load balancer with policy '
                     f'{load_balancing_policy_name}.')
         self._request_aggregator: serve_utils.RequestsAggregator = (
@@ -73,6 +87,9 @@ class SkyServeLoadBalancer:

     async def _sync_with_controller_once(self) -> List[asyncio.Task]:
         close_client_tasks = []
+        ready_replica_urls = []
+        replica_info = {}
+
         async with aiohttp.ClientSession() as session:
             try:
                 # Send request information
@@ -88,8 +105,8 @@ class SkyServeLoadBalancer:
                     self._request_aggregator.clear()
                     response.raise_for_status()
                     response_json = await response.json()
-                    ready_replica_urls = response_json.get(
-                        'ready_replica_urls', [])
+                    replica_info = response_json.get('replica_info', {})
+                    ready_replica_urls = list(replica_info.keys())
                 except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                     logger.error(f'An error occurred when syncing with '
                                  f'the controller: {e}'
@@ -99,6 +116,11 @@ class SkyServeLoadBalancer:
         with self._client_pool_lock:
             self._load_balancing_policy.set_ready_replicas(
                 ready_replica_urls)
+            # Set replica info for instance-aware policies
+            if isinstance(self._load_balancing_policy,
+                          lb_policies.InstanceAwareLeastLoadPolicy):
+                self._load_balancing_policy.set_replica_info(
+                    replica_info)
             for replica_url in ready_replica_urls:
                 if replica_url not in self._client_pool:
                     self._client_pool[replica_url] = httpx.AsyncClient(
@@ -265,23 +287,31 @@ class SkyServeLoadBalancer:


 def run_load_balancer(
-        controller_addr: str,
-        load_balancer_port: int,
-        load_balancing_policy_name: Optional[str] = None,
-        tls_credential: Optional[serve_utils.TLSCredential] = None) -> None:
+    controller_addr: str,
+    load_balancer_port: int,
+    load_balancing_policy_name: Optional[str] = None,
+    tls_credential: Optional[serve_utils.TLSCredential] = None,
+    target_qps_per_replica: Optional[Union[float, Dict[str, float]]] = None
+) -> None:
     """ Run the load balancer.

     Args:
         controller_addr: The address of the controller.
         load_balancer_port: The port where the load balancer listens to.
-        policy_name: The name of the load balancing policy to use. Defaults to
-            None.
+        policy_name: The name of the load balancing policy to use.
+            Defaults to None.
+        tls_credential:
+            The TLS credentials for HTTPS endpoint. Defaults to None.
+        target_qps_per_replica: Target QPS per replica for instance-aware
+            load balancing. Can be a float or dict mapping GPU types to QPS.
+            Defaults to None.
     """
     load_balancer = SkyServeLoadBalancer(
         controller_url=controller_addr,
         load_balancer_port=load_balancer_port,
         load_balancing_policy_name=load_balancing_policy_name,
-        tls_credential=tls_credential)
+        tls_credential=tls_credential,
+        target_qps_per_replica=target_qps_per_replica)
     load_balancer.run()

@@ -305,5 +335,8 @@ if __name__ == '__main__':
         help=f'The load balancing policy to use. Available policies: '
              f'{", ".join(available_policies)}.')
     args = parser.parse_args()
-    run_load_balancer(args.controller_addr, args.load_balancer_port,
-                      args.load_balancing_policy)
+    run_load_balancer(args.controller_addr,
+                      args.load_balancer_port,
+                      args.load_balancing_policy,
+                      tls_credential=None,
+                      target_qps_per_replica=None)
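Note the CLI entrypoint passes None for both new parameters, so instance-aware load balancing stays off unless a caller supplies the dict explicitly. A hypothetical direct invocation (the policy name and values are assumptions, not confirmed by this diff):

    run_load_balancer(
        controller_addr='http://localhost:30001',
        load_balancer_port=8890,
        load_balancing_policy_name='instance_aware_least_load',  # assumed registry name
        tls_credential=None,
        target_qps_per_replica={'A100': 10.0, 'L4': 4.0},
    )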