sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  12. sglang/srt/layers/attention/vision.py +243 -40
  13. sglang/srt/layers/dp_attention.py +3 -1
  14. sglang/srt/layers/layernorm.py +5 -5
  15. sglang/srt/layers/linear.py +24 -9
  16. sglang/srt/layers/logits_processor.py +1 -1
  17. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  18. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  22. sglang/srt/layers/parameter.py +16 -7
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/fp8.py +11 -1
  33. sglang/srt/layers/rotary_embedding.py +34 -13
  34. sglang/srt/layers/sampler.py +33 -10
  35. sglang/srt/layers/torchao_utils.py +12 -6
  36. sglang/srt/managers/detokenizer_manager.py +1 -0
  37. sglang/srt/managers/image_processor.py +77 -38
  38. sglang/srt/managers/io_struct.py +36 -5
  39. sglang/srt/managers/schedule_batch.py +31 -25
  40. sglang/srt/managers/scheduler.py +78 -38
  41. sglang/srt/managers/tokenizer_manager.py +4 -0
  42. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  43. sglang/srt/mem_cache/chunk_cache.py +3 -0
  44. sglang/srt/mem_cache/radix_cache.py +30 -1
  45. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  46. sglang/srt/model_executor/forward_batch_info.py +5 -7
  47. sglang/srt/model_executor/model_runner.py +7 -4
  48. sglang/srt/model_loader/loader.py +75 -0
  49. sglang/srt/model_loader/weight_utils.py +91 -5
  50. sglang/srt/models/commandr.py +14 -2
  51. sglang/srt/models/dbrx.py +9 -1
  52. sglang/srt/models/deepseek_v2.py +3 -3
  53. sglang/srt/models/gemma2.py +9 -1
  54. sglang/srt/models/grok.py +1 -0
  55. sglang/srt/models/minicpm3.py +3 -3
  56. sglang/srt/models/minicpmv.py +129 -76
  57. sglang/srt/models/mllama.py +16 -56
  58. sglang/srt/models/qwen2.py +4 -1
  59. sglang/srt/models/qwen2_vl.py +18 -8
  60. sglang/srt/models/torch_native_llama.py +17 -4
  61. sglang/srt/openai_api/adapter.py +139 -37
  62. sglang/srt/openai_api/protocol.py +5 -4
  63. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  64. sglang/srt/sampling/sampling_batch_info.py +4 -14
  65. sglang/srt/server.py +2 -2
  66. sglang/srt/server_args.py +26 -1
  67. sglang/srt/speculative/eagle_utils.py +37 -15
  68. sglang/srt/speculative/eagle_worker.py +11 -13
  69. sglang/srt/utils.py +62 -67
  70. sglang/test/test_programs.py +1 -0
  71. sglang/test/test_utils.py +81 -22
  72. sglang/utils.py +42 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
  75. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
  76. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -149,6 +149,7 @@ class Scheduler:
             if not self.spec_algorithm.is_none()
             else 1
         )
+        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache

         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -281,6 +282,7 @@ class Scheduler:
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
@@ -408,6 +410,11 @@ class Scheduler:
             },
         )

+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -480,7 +487,7 @@ class Scheduler:
     @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        result_queue = deque()
+        self.result_queue = deque()

         while True:
             recv_reqs = self.recv_requests()
@@ -491,7 +498,7 @@ class Scheduler:

             if batch:
                 result = self.run_batch(batch)
-                result_queue.append((batch.copy(), result))
+                self.result_queue.append((batch.copy(), result))

             if self.last_batch is None:
                 # Create a dummy first batch to start the pipeline for overlap schedule.
@@ -505,7 +512,7 @@ class Scheduler:

             if self.last_batch:
                 # Process the results of the last batch
-                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
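
The three hunks above promote the overlap scheduler's pending-result queue from a local variable to `self.result_queue`, but the pattern itself is unchanged: the result of batch N is enqueued when the batch is launched and only consumed on the next loop iteration, while batch N+1 is already running. A minimal, self-contained sketch of that pattern with toy callables (not sglang's Scheduler API):

```python
from collections import deque

def event_loop_overlap_sketch(get_next_batch, run_batch, process_result):
    """Toy overlap loop: post-process the previous batch's result on the CPU
    while the current batch is (conceptually) still running on the GPU."""
    result_queue = deque()          # holds (batch, pending_result) pairs
    last_batch = None

    while True:
        batch = get_next_batch()
        if batch is not None:
            # Launch the current batch; its result is consumed next iteration.
            result_queue.append((batch, run_batch(batch)))

        if last_batch is not None:
            # Overlap point: handle the batch launched one iteration ago.
            prev_batch, prev_result = result_queue.popleft()
            process_result(prev_batch, prev_result)

        if batch is None and last_batch is None:
            break                   # nothing in flight and nothing new
        last_batch = batch

if __name__ == "__main__":
    pending = deque([["a"], ["b", "c"]])
    event_loop_overlap_sketch(
        get_next_batch=lambda: pending.popleft() if pending else None,
        run_batch=lambda b: [tok.upper() for tok in b],   # stands in for the GPU forward
        process_result=lambda b, r: print(b, "->", r),
    )
```

Keeping the queue on the object, as the diff does, makes it reachable from other scheduler methods; the sketch keeps it local for brevity.
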
@@ -636,7 +643,7 @@ class Scheduler:
             self.waiting_queue.append(req)
             return

-        # Handle image inputs
+        # Handle multimodal inputs
         if recv_req.image_inputs is not None:
             image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
@@ -660,24 +667,23 @@ class Scheduler:
             self.waiting_queue.append(req)
             return

-        # Copy more attributes
-        req.logprob_start_len = recv_req.logprob_start_len
-
-        if req.logprob_start_len == -1:
-            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
-
         # Validate prompts length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
-
         if error_msg:
             self.waiting_queue.append(req)
             return

+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
@@ -725,15 +731,26 @@ class Scheduler:
         req.tokenizer = self.tokenizer

         # Validate prompts length
-        validate_input_length(
+        error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return

+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         self.waiting_queue.append(req)

-    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
+    def log_prefill_stats(
+        self,
+        adder: PrefillAdder,
+        can_run_list: List[Req],
+        running_bs: ScheduleBatch,
+        has_being_chunked: bool,
+    ):
         self.tree_cache_metrics["total"] += (
             adder.log_input_tokens + adder.log_hit_tokens
         ) / 10**9
@@ -815,10 +832,16 @@ class Scheduler:
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        if available_size != self.max_total_num_tokens:
+        protected_size = self.tree_cache.protected_size()
+        memory_leak = available_size != (
+            self.max_total_num_tokens
+            if not self.enable_hierarchical_cache
+            else self.max_total_num_tokens - protected_size
+        )
+        if memory_leak:
             msg = (
                 "KV cache pool leak detected!"
-                f"{available_size=}, {self.max_total_num_tokens=}\n"
+                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
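
The memory-leak check above now subtracts the tree cache's protected (lock-referenced) size from the expected total when hierarchical caching is enabled. A standalone sketch of that accounting, with illustrative names rather than sglang's attributes:

```python
def kv_pool_leaked(
    pool_available: int,
    evictable: int,
    protected: int,
    max_total_tokens: int,
    hierarchical_cache: bool,
) -> bool:
    """Return True if the token accounting no longer adds up (a suspected leak)."""
    available_size = pool_available + evictable
    expected = (
        max_total_tokens
        if not hierarchical_cache
        else max_total_tokens - protected
    )
    return available_size != expected

# 100 free + 20 evictable, with 8 tokens still locked (protected) out of 128 total:
# 120 == 128 - 8, so the pool balances and no warning would be raised.
assert not kv_pool_leaked(100, 20, 8, 128, hierarchical_cache=True)
```
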
@@ -933,7 +956,14 @@ class Scheduler:
             res = adder.add_one_req(req)
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
-                    self.batch_is_full = True
+                    if self.enable_hierarchical_cache:
+                        # Set batch_is_full after making sure there are requests that can be served
+                        self.batch_is_full = len(adder.can_run_list) > 0 or (
+                            self.running_batch is not None
+                            and not self.running_batch.is_empty()
+                        )
+                    else:
+                        self.batch_is_full = True
                 break
             if self.server_args.prefill_only_one_req:
                 break
@@ -1023,7 +1053,7 @@ class Scheduler:
         )

         # Check for jump-forward
-        if not self.disable_jump_forward:
+        if not self.disable_jump_forward and batch.has_grammar:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -1044,26 +1074,23 @@ class Scheduler:
         self.forward_ct += 1

         if self.is_generation:
-            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                if self.spec_algorithm.is_none():
-                    model_worker_batch = batch.get_model_worker_batch()
-                    logits_output, next_token_ids = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    (
-                        logits_output,
-                        next_token_ids,
-                        model_worker_batch,
-                        num_accepted_tokens,
-                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                    self.spec_num_total_accepted_tokens += (
-                        num_accepted_tokens + batch.batch_size()
-                    )
-                    self.spec_num_total_forward_ct += batch.batch_size()
-                    self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids

         ret = GenerationBatchResult(
@@ -1072,7 +1099,6 @@ class Scheduler:
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
@@ -1371,6 +1397,7 @@ class Scheduler:
             prompt_tokens = []
             completion_tokens = []
             cached_tokens = []
+            spec_verify_ct = []

             if return_logprob:
                 input_token_logprobs_val = []
@@ -1424,6 +1451,9 @@ class Scheduler:
                 completion_tokens.append(len(req.output_ids))
                 cached_tokens.append(req.cached_tokens)

+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
                 if return_logprob:
                     input_token_logprobs_val.append(req.input_token_logprobs_val)
                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1481,7 @@ class Scheduler:
                     prompt_tokens,
                     completion_tokens,
                     cached_tokens,
+                    spec_verify_ct,
                     input_token_logprobs_val,
                     input_token_logprobs_idx,
                     output_token_logprobs_val,
@@ -1564,6 +1595,15 @@ class Scheduler:
             self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
+
+            if not self.spec_algorithm.is_none():
+                self.draft_worker.model_runner.req_to_token_pool.clear()
+                self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+            self.num_generated_tokens = 0
+            self.forward_ct_decode = 0
+            self.spec_num_total_accepted_tokens = 0
+            self.spec_num_total_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
sglang/srt/managers/tokenizer_manager.py

@@ -785,6 +785,9 @@ class TokenizerManager:
                 i,
             )

+            if self.server_args.speculative_algorithm:
+                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -809,6 +812,7 @@ class TokenizerManager:
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
                 }
+
             state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
sglang/srt/mem_cache/base_prefix_cache.py

@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
     def evictable_size(self):
         pass

+    @abstractmethod
+    def protected_size(self):
+        raise NotImplementedError()
+
     def total_size(self):
         raise NotImplementedError()

sglang/srt/mem_cache/chunk_cache.py

@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):

     def evictable_size(self):
         return 0
+
+    def protected_size(self):
+        return 0
sglang/srt/mem_cache/radix_cache.py

@@ -34,7 +34,10 @@ if TYPE_CHECKING:


 class TreeNode:
-    def __init__(self):
+
+    counter = 0
+
+    def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
         self.parent = None
         self.key = None
@@ -42,6 +45,23 @@ class TreeNode:
         self.lock_ref = 0
         self.last_access_time = time.time()

+        self.hit_count = 0
+        # indicating the node is loading KV cache from host
+        self.loading = False
+        # store the host indices of KV cache
+        self.host_value = None
+
+        self.id = TreeNode.counter if id is None else id
+        TreeNode.counter += 1
+
+    @property
+    def evicted(self):
+        return self.value is None
+
+    @property
+    def backuped(self):
+        return self.host_value is not None
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time

@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
+        self.protected_size_ = 0

     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -203,6 +224,7 @@
         while node != self.root_node:
             if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
+                self.protected_size_ += len(node.value)
                 delta -= len(node.value)
             node.lock_ref += 1
             node = node.parent
@@ -216,6 +238,7 @@
         while node != self.root_node:
             if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
+                self.protected_size_ -= len(node.value)
                 delta += len(node.value)
             node.lock_ref -= 1
             node = node.parent
@@ -224,6 +247,10 @@
     def evictable_size(self):
         return self.evictable_size_

+    def protected_size(self):
+        # protected size refers to the size of the cache that is locked
+        return self.protected_size_
+
     ##### Internal Helper Functions #####

     def _match_prefix_helper(
@@ -303,6 +330,8 @@
         self.evictable_size_ -= len(node.key)

     def _total_size_helper(self, node: TreeNode):
+        if node.evicted:
+            return 0
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
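
The RadixCache hunks above pair every `evictable_size_` update in the lock/unlock paths with an opposite `protected_size_` update, so the two counters partition cached tokens into evictable vs. locked. A toy sketch of that bookkeeping (simplified node class, not sglang's RadixCache):

```python
class _Node:
    """Simplified radix-tree node: `value` is a list of KV-cache slots."""
    def __init__(self, value, parent=None):
        self.value = value
        self.parent = parent
        self.lock_ref = 0

class _CacheAccounting:
    """Mirror of the evictable/protected bookkeeping added in the hunks above."""
    def __init__(self):
        self.evictable_size_ = 0
        self.protected_size_ = 0

    def inc_lock_ref(self, node):
        # Locking a node the first time moves its tokens from evictable to protected.
        while node is not None and node.parent is not None:
            if node.lock_ref == 0:
                self.evictable_size_ -= len(node.value)
                self.protected_size_ += len(node.value)
            node.lock_ref += 1
            node = node.parent

    def dec_lock_ref(self, node):
        # Dropping the last lock moves the tokens back to the evictable pool.
        while node is not None and node.parent is not None:
            if node.lock_ref == 1:
                self.evictable_size_ += len(node.value)
                self.protected_size_ -= len(node.value)
            node.lock_ref -= 1
            node = node.parent

root = _Node(value=[], parent=None)
child = _Node(value=[1, 2, 3], parent=root)
acct = _CacheAccounting()
acct.evictable_size_ = 3       # the child's 3 tokens start out evictable
acct.inc_lock_ref(child)       # a running request pins the prefix
assert (acct.evictable_size_, acct.protected_size_) == (0, 3)
acct.dec_lock_ref(child)
assert (acct.evictable_size_, acct.protected_size_) == (3, 0)
```
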
sglang/srt/model_executor/cuda_graph_runner.py

@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if batch_size == 1:
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
@@ -55,22 +55,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse, batch_size)
+            _to_torch(sub, reverse, num_tokens)


 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-    batch_size: int,
-    tp_group: "GroupCoordinator",
+    num_tokens: int,
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model, reverse=False, batch_size=batch_size)
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
         yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True, batch_size=batch_size)
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
             tp_group.ca_comm = backup_ca_comm


@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]

+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
-
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )

-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

@@ -189,14 +187,14 @@

         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

             # Speculative_inference
             if model_runner.spec_algorithm.is_eagle():
@@ -285,8 +283,8 @@
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
-                bs,
-                self.model_runner.tp_group,
+                num_tokens=bs * self.num_tokens_per_bs,
+                tp_group=self.model_runner.tp_group,
             ) as forward:
                 (
                     graph,
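
The last hunk passes `num_tokens=bs * self.num_tokens_per_bs` to `patch_model` instead of the raw batch size: with EAGLE speculative decoding, each sequence contributes several draft tokens to a captured graph, so the relevant size is a token count rather than a batch size. A small hypothetical helper (not sglang code) for that arithmetic:

```python
from typing import Optional

def graph_tokens_for_batch(bs: int, speculative_num_draft_tokens: Optional[int] = None) -> int:
    """Tokens a captured decode graph processes for batch size `bs`: one per
    sequence normally, `speculative_num_draft_tokens` per sequence for an
    EAGLE verify graph."""
    num_tokens_per_bs = speculative_num_draft_tokens or 1
    return bs * num_tokens_per_bs

assert graph_tokens_for_batch(8) == 8        # plain decode
assert graph_tokens_for_batch(8, 4) == 32    # EAGLE verify: 4 draft tokens per sequence
```
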
sglang/srt/model_executor/forward_batch_info.py

@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import maybe_torch_compile
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -282,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ class ForwardBatch:
         if model_runner.model_is_mrope:
             ret.compute_mrope_positions(model_runner, batch)

-        # Init attention information
-        ret.req_to_token_pool = model_runner.req_to_token_pool
-        ret.token_to_kv_pool = model_runner.token_to_kv_pool
-        ret.attn_backend = model_runner.attn_backend
-
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
@@ -417,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc


-@maybe_torch_compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
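
The `clamp_position` hunk swaps the `maybe_torch_compile` wrapper for a direct `torch.compile` call whose backend comes from `get_compiler_backend()`. A minimal sketch of that pattern with an assumed backend-selection rule (sglang's actual `get_compiler_backend` may decide differently):

```python
import torch

def pick_compiler_backend() -> str:
    # Assumption for illustration: inductor when CUDA is present, eager otherwise.
    return "inductor" if torch.cuda.is_available() else "eager"

@torch.compile(dynamic=True, backend=pick_compiler_backend())
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Last valid position per sequence, clamped at 0 for empty sequences.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

print(clamp_position(torch.tensor([0, 1, 5])))  # tensor([0, 0, 4])
```
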
sglang/srt/model_executor/model_runner.py

@@ -185,9 +185,12 @@ class ModelRunner:
         self.load_model()

         # Apply torchao quantization
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )

         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
@@ -215,7 +218,7 @@

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        # Init torch distributed
+
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
sglang/srt/model_loader/loader.py

@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()


+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""

@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)

+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
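
The new `LayeredModelLoader` builds the model on the meta device, then materializes, fills, and optionally quantizes one submodule at a time, so only one layer's full-precision weights need to be resident at their peak. A self-contained sketch of that idea on a toy model; `layered_load` and its arguments are illustrative, not sglang's API:

```python
from typing import Callable, Dict
import torch
import torch.nn as nn

def layered_load(
    build_on_meta: Callable[[], nn.Module],
    weights: Dict[str, torch.Tensor],
    transform: Callable[[nn.Module], None],
    device: torch.device,
) -> nn.Module:
    """Materialize and fill one submodule at a time to keep peak memory low."""
    with torch.device("meta"):
        model = build_on_meta()               # parameters allocated as meta tensors

    def fill(module: nn.Module, prefix: str) -> None:
        for name, child in module.named_children():
            fill(child, f"{prefix}{name}.")
        module.to_empty(device=device, recurse=False)   # allocate this module only
        for name, param in module.named_parameters(recurse=False):
            param.data.copy_(weights[prefix + name])
        transform(module)                     # e.g. quantize before the next layer loads

    fill(model, "")
    return model.eval()

# Usage sketch: a two-layer MLP whose weights come from a flat state dict.
state = {k: torch.randn_like(v, device="cpu")
         for k, v in nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2)).state_dict().items()}
model = layered_load(
    build_on_meta=lambda: nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2)),
    weights=state,
    transform=lambda m: None,                 # no-op stands in for torchao quantization
    device=torch.device("cpu"),
)
print(model(torch.randn(1, 4)).shape)         # torch.Size([1, 2])
```

In the real loader, the per-module transform step is where `apply_torchao_config_to_model` runs, which is why the `ModelRunner` hunk above skips re-applying torchao when `model.torchao_applied` is set.
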