sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,7 @@ import logging
  import threading
  from enum import IntEnum
  from functools import wraps
- from typing import List, Tuple, Union
+ from typing import List, Optional, Tuple, Union
 
  import numpy as np
  import psutil
@@ -49,7 +49,6 @@ class ReqToTokenPool:
  size: int,
  max_context_len: int,
  device: str,
- use_records: bool,
  enable_memory_saver: bool,
  ):
  memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -64,17 +63,9 @@ class ReqToTokenPool:
  (size, max_context_len), dtype=torch.int32, device=device
  )
  self.free_slots = list(range(size))
- self.write_records = []
- self.use_records = use_records
-
- if self.use_records:
- self.write = self.write_with_records
- else:
- self.write = self.write_without_records
 
  def write(self, indices, values):
- # Keep the signature for type checking. It will be assigned during runtime.
- raise NotImplementedError()
+ self.req_to_token[indices] = values
 
  def available_size(self):
  return len(self.free_slots)
@@ -96,23 +87,6 @@ class ReqToTokenPool:
 
  def clear(self):
  self.free_slots = list(range(self.size))
- self.write_records = []
-
- def write_without_records(self, indices, values):
- self.req_to_token[indices] = values
-
- def write_with_records(self, indices, values):
- self.req_to_token[indices] = values
- self.write_records.append((indices, values))
-
- def get_write_records(self):
- ret = self.write_records
- self.write_records = []
- return ret
-
- def apply_write_records(self, write_records: List[Tuple]):
- for indices, values in write_records:
- self.req_to_token[indices] = values
 
 
  class BaseTokenToKVPool:
@@ -296,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
  loc: torch.Tensor,
  cache_k: torch.Tensor,
  cache_v: torch.Tensor,
- k_scale: float = 1.0,
- v_scale: float = 1.0,
+ k_scale: Optional[float] = None,
+ v_scale: Optional[float] = None,
  ):
  layer_id = layer.layer_id
  if cache_k.dtype != self.dtype:
- cache_k = (cache_k / k_scale).to(self.dtype)
- cache_v = (cache_v / v_scale).to(self.dtype)
+ if k_scale is not None:
+ cache_k.div_(k_scale)
+ if v_scale is not None:
+ cache_v.div_(v_scale)
+ cache_k = cache_k.to(self.dtype)
+ cache_v = cache_v.to(self.dtype)
  if self.store_dtype != self.dtype:
  self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
  self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
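The hunks above appear to come from sglang/srt/mem_cache/memory_pool.py (entry 78 in the file list): the write-record machinery of ReqToTokenPool is removed, write() becomes a plain assignment, and the KV-cache setter now treats k_scale/v_scale as optional, dividing in place only when a scale is given before a single dtype cast. Below is a minimal standalone sketch of that optional-scale path; the helper name and the bfloat16 target are illustrative only (the real pool casts to its configured KV-cache storage dtype, e.g. an fp8 format).

```python
from typing import Optional, Tuple

import torch


def cast_kv_with_optional_scale(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    target_dtype: torch.dtype,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper mirroring the new logic: divide in place only
    when a scale is provided, then cast once to the storage dtype."""
    if cache_k.dtype != target_dtype:
        if k_scale is not None:
            cache_k.div_(k_scale)  # in-place, no extra temporary tensor
        if v_scale is not None:
            cache_v.div_(v_scale)
        cache_k = cache_k.to(target_dtype)
        cache_v = cache_v.to(target_dtype)
    return cache_k, cache_v


k = torch.randn(4, 8, dtype=torch.float16)
v = torch.randn(4, 8, dtype=torch.float16)
# The old code divided by a default scale of 1.0 unconditionally; with
# None meaning "no scaling", unscaled caches now skip the division.
k_cast, v_cast = cast_kv_with_optional_scale(k, v, torch.bfloat16, k_scale=2.0)
print(k_cast.dtype, v_cast.dtype)  # torch.bfloat16 torch.bfloat16
```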
@@ -25,6 +25,7 @@ class SchedulerStats:
  gen_throughput: float = 0.0
  num_queue_reqs: int = 0
  cache_hit_rate: float = 0.0
+ spec_accept_length: float = 0.0
 
 
  class SchedulerMetricsCollector:
@@ -37,42 +38,49 @@ class SchedulerMetricsCollector:
 
  self.num_running_reqs = Gauge(
  name="sglang:num_running_reqs",
- documentation="The number of running requests",
+ documentation="The number of running requests.",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )
 
  self.num_used_tokens = Gauge(
  name="sglang:num_used_tokens",
- documentation="The number of used tokens",
+ documentation="The number of used tokens.",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )
 
  self.token_usage = Gauge(
  name="sglang:token_usage",
- documentation="The token usage",
+ documentation="The token usage.",
  labelnames=labels.keys(),
  multiprocess_mode="mostrecent",
  )
 
  self.gen_throughput = Gauge(
  name="sglang:gen_throughput",
- documentation="The generate throughput (token/s)",
+ documentation="The generation throughput (token/s).",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )
 
  self.num_queue_reqs = Gauge(
  name="sglang:num_queue_reqs",
- documentation="The number of requests in the waiting queue",
+ documentation="The number of requests in the waiting queue.",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )
 
  self.cache_hit_rate = Gauge(
  name="sglang:cache_hit_rate",
- documentation="The cache hit rate",
+ documentation="The prefix cache hit rate.",
+ labelnames=labels.keys(),
+ multiprocess_mode="mostrecent",
+ )
+
+ self.spec_accept_length = Gauge(
+ name="sglang:spec_accept_length",
+ documentation="The average acceptance length of speculative decoding.",
  labelnames=labels.keys(),
  multiprocess_mode="mostrecent",
  )
@@ -88,6 +96,7 @@ class SchedulerMetricsCollector:
  self._log_gauge(self.gen_throughput, stats.gen_throughput)
  self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
  self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+ self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
 
 
  class TokenizerMetricsCollector:
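These hunks appear to be from sglang/srt/metrics/collector.py (entry 79): SchedulerStats gains a spec_accept_length field, the gauge documentation strings are normalized, and a new sglang:spec_accept_length gauge is exported. A small sketch of defining and updating such a gauge with prometheus_client follows; the label set is a placeholder, the multiprocess_mode value only takes effect when the multiprocess collector is used, and a recent prometheus_client release (one that accepts "mostrecent") is assumed.

```python
from prometheus_client import Gauge

labels = {"model_name": "example-model"}  # placeholder label set

spec_accept_length = Gauge(
    name="sglang:spec_accept_length",
    documentation="The average acceptance length of speculative decoding.",
    labelnames=labels.keys(),
    multiprocess_mode="mostrecent",
)


def log_gauge(gauge: Gauge, value: float) -> None:
    # Mirrors the collector's _log_gauge pattern: set the labeled child.
    gauge.labels(**labels).set(value)


# e.g. an average of 3.2 accepted draft tokens per speculation step
log_gauge(spec_accept_length, 3.2)
```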
@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable
 
  import torch
  import tqdm
- from vllm.distributed import get_tensor_model_parallel_rank
- from vllm.distributed.parallel_state import graph_capture
  from vllm.model_executor.custom_op import CustomOp
 
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
  from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -33,13 +33,12 @@ from sglang.srt.model_executor.forward_batch_info import (
  ForwardBatch,
  ForwardMode,
  )
- from sglang.srt.utils import monkey_patch_vllm_all_gather
 
  if TYPE_CHECKING:
  from sglang.srt.model_executor.model_runner import ModelRunner
 
 
- def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
  for sub in model._modules.values():
  if isinstance(sub, CustomOp):
  if reverse:
@@ -48,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
  else:
  # NOTE: Temporarily workaround MoE
  if "FusedMoE" in sub.__class__.__name__:
- if batch_size == 1:
+ if num_tokens == 1:
  # The performance of torch.compile on this layer is not always good when bs > 1,
  # so we decide to only use torch.compile when bs =1
  sub._forward_method = fused_moe_forward_native
@@ -56,23 +55,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
  sub._forward_method = sub.forward_native
  setattr(sub, "is_torch_compile", True)
  if isinstance(sub, torch.nn.Module):
- _to_torch(sub, reverse, batch_size)
+ _to_torch(sub, reverse, num_tokens)
 
 
  @contextmanager
  def patch_model(
  model: torch.nn.Module,
  enable_compile: bool,
- batch_size: int,
- tp_group: "GroupCoordinator",
+ num_tokens: int,
+ tp_group: GroupCoordinator,
  ):
  """Patch the model to make it compatible with with torch.compile"""
  backup_ca_comm = None
 
  try:
  if enable_compile:
- _to_torch(model, reverse=False, batch_size=batch_size)
- monkey_patch_vllm_all_gather()
+ _to_torch(model, reverse=False, num_tokens=num_tokens)
  backup_ca_comm = tp_group.ca_comm
  # Use custom-allreduce here.
  # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -87,8 +85,7 @@ def patch_model(
  yield model.forward
  finally:
  if enable_compile:
- _to_torch(model, reverse=True, batch_size=batch_size)
- monkey_patch_vllm_all_gather(reverse=True)
+ _to_torch(model, reverse=True, num_tokens=num_tokens)
  tp_group.ca_comm = backup_ca_comm
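The hunks above appear to come from sglang/srt/model_executor/cuda_graph_runner.py (entry 80): _to_torch and patch_model are now keyed on num_tokens rather than batch_size, GroupCoordinator is imported from sglang.srt.distributed, and the vLLM all-gather monkey patch is dropped. Here is a toy sketch of the underlying pattern only, swapping a custom op's dispatch to its torch-native path inside a context manager so torch.compile can trace it, then restoring the original dispatch on exit. All class and function names are stand-ins, and a PyTorch build where torch.compile is usable is assumed.

```python
from contextlib import contextmanager

import torch
import torch.nn.functional as F


class ToyCustomOp(torch.nn.Module):
    """Stand-in for a CustomOp with a fused CUDA path and a torch-native path."""

    def forward_cuda(self, x):
        return F.silu(x)  # pretend this is a hand-written fused kernel

    def forward_native(self, x):
        return x * torch.sigmoid(x)  # same math, traceable pure-torch form

    def forward(self, x):
        return self._forward_method(x)

    _forward_method = forward_cuda  # default dispatch


def to_torch_native(module: torch.nn.Module, reverse: bool) -> None:
    # Recursively repoint _forward_method, mirroring the _to_torch pattern above.
    for sub in module.modules():
        if isinstance(sub, ToyCustomOp):
            sub._forward_method = sub.forward_cuda if reverse else sub.forward_native


@contextmanager
def patch_for_compile(model: torch.nn.Module, enable_compile: bool):
    try:
        if enable_compile:
            to_torch_native(model, reverse=False)
        yield torch.compile(model, dynamic=False) if enable_compile else model
    finally:
        if enable_compile:
            to_torch_native(model, reverse=True)  # restore the original dispatch


model = torch.nn.Sequential(torch.nn.Linear(8, 8), ToyCustomOp())
with patch_for_compile(model, enable_compile=True) as forward:
    y = forward(torch.randn(4, 8))
```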
@@ -122,6 +119,7 @@ class CudaGraphRunner:
  self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
  self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
  self.tp_size = self.model_runner.tp_size
+ self.dp_size = self.model_runner.server_args.dp_size
 
  # Batch sizes to capture
  self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -151,9 +149,18 @@ class CudaGraphRunner:
  and bs <= model_runner.server_args.cuda_graph_max_bs
  ]
 
+ self.compile_bs = (
+ [
+ bs
+ for bs in self.capture_bs
+ if bs <= self.model_runner.server_args.torch_compile_max_bs
+ ]
+ if self.use_torch_compile
+ else []
+ )
+
  self.capture_forward_mode = ForwardMode.DECODE
  self.num_tokens_per_bs = 1
-
  if model_runner.spec_algorithm.is_eagle():
  if self.model_runner.is_draft_worker:
  self.num_tokens_per_bs = (
@@ -165,16 +172,6 @@ class CudaGraphRunner:
  self.model_runner.server_args.speculative_num_draft_tokens
  )
 
- self.compile_bs = (
- [
- bs
- for bs in self.capture_bs
- if bs <= self.model_runner.server_args.torch_compile_max_bs
- ]
- if self.use_torch_compile
- else []
- )
-
  # Attention backend
  self.max_bs = max(self.capture_bs)
  self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -182,7 +179,6 @@ class CudaGraphRunner:
  self.seq_len_fill_value = (
  self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
  )
-
  # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
  self.encoder_len_fill_value = 0
 
@@ -191,14 +187,14 @@ class CudaGraphRunner:
 
  # Common inputs
  with torch.device("cuda"):
- self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+ self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
  self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
  self.seq_lens = torch.full(
  (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
  )
- self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+ self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
  self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
- self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+ self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
  # Speculative_inference
  if model_runner.spec_algorithm.is_eagle():
@@ -218,7 +214,7 @@ class CudaGraphRunner:
  if self.enable_dp_attention:
  self.gathered_buffer = torch.zeros(
  (
- self.max_bs * self.tp_size,
+ self.max_bs * self.dp_size,
  self.model_runner.model_config.hidden_size,
  ),
  dtype=self.model_runner.dtype,
@@ -287,8 +283,8 @@ class CudaGraphRunner:
  with patch_model(
  self.model_runner.model,
  bs in self.compile_bs,
- bs,
- self.model_runner.tp_group,
+ num_tokens=bs * self.num_tokens_per_bs,
+ tp_group=self.model_runner.tp_group,
  ) as forward:
  (
  graph,
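Continuing in cuda_graph_runner.py, compile_bs is now computed before the EAGLE branch, the input_ids/out_cache_loc/mrope_positions capture buffers switch to int64, the DP-attention gather buffer is sized by dp_size instead of tp_size, and patch_model receives num_tokens=bs * num_tokens_per_bs. A small sketch of that batch-size bookkeeping with made-up server-argument values (the real lists come from the server args and the speculative-decoding config):

```python
from dataclasses import dataclass


@dataclass
class FakeServerArgs:
    # Hypothetical values, for illustration only.
    cuda_graph_max_bs: int = 160
    torch_compile_max_bs: int = 32
    speculative_num_draft_tokens: int = 4


args = FakeServerArgs()
use_torch_compile = True

# Candidate graph-capture batch sizes, clipped to cuda_graph_max_bs.
capture_bs = [
    bs for bs in (1, 2, 4, 8, 16, 32, 64, 128, 160, 256)
    if bs <= args.cuda_graph_max_bs
]

# Only the smaller sizes are additionally compiled with torch.compile.
compile_bs = (
    [bs for bs in capture_bs if bs <= args.torch_compile_max_bs]
    if use_torch_compile
    else []
)

# With EAGLE speculative decoding each batch slot carries several draft
# tokens, so the captured graph input length is bs * num_tokens_per_bs.
num_tokens_per_bs = args.speculative_num_draft_tokens
num_tokens = {bs: bs * num_tokens_per_bs for bs in capture_bs}

print(capture_bs)     # [1, 2, 4, 8, 16, 32, 64, 128, 160]
print(compile_bs)     # [1, 2, 4, 8, 16, 32]
print(num_tokens[8])  # 32 tokens captured for a batch of 8
```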
@@ -38,7 +38,7 @@ import triton
  import triton.language as tl
 
  from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
- from sglang.srt.utils import maybe_torch_compile
+ from sglang.srt.utils import get_compiler_backend
 
  if TYPE_CHECKING:
  from sglang.srt.layers.attention import AttentionBackend
@@ -282,6 +282,9 @@ class ForwardBatch:
  can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
  lora_paths=batch.lora_paths,
  sampling_info=batch.sampling_info,
+ req_to_token_pool=model_runner.req_to_token_pool,
+ token_to_kv_pool=model_runner.token_to_kv_pool,
+ attn_backend=model_runner.attn_backend,
  spec_algorithm=batch.spec_algorithm,
  spec_info=batch.spec_info,
  capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ class ForwardBatch:
  if model_runner.model_is_mrope:
  ret.compute_mrope_positions(model_runner, batch)
 
- # Init attention information
- ret.req_to_token_pool = model_runner.req_to_token_pool
- ret.token_to_kv_pool = model_runner.token_to_kv_pool
- ret.attn_backend = model_runner.attn_backend
-
  # Init lora information
  if model_runner.server_args.lora_paths is not None:
  model_runner.lora_manager.prepare_lora_batch(ret)
@@ -417,6 +415,6 @@ def compute_position_torch(
  return positions.to(torch.int64), extend_start_loc
 
 
- @maybe_torch_compile(dynamic=True)
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
  def clamp_position(seq_lens):
  return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
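These hunks appear to be from sglang/srt/model_executor/forward_batch_info.py (entry 81): the KV pools and attention backend are now passed to the ForwardBatch constructor instead of being attached afterwards, and the maybe_torch_compile wrapper is replaced by a direct torch.compile call with a backend chosen by get_compiler_backend. A self-contained sketch of that decorator pattern; the backend selector below is a stand-in (the real helper lives in sglang.srt.utils and presumably picks a backend per device, "inductor" being torch.compile's default):

```python
import torch


def get_compiler_backend() -> str:
    # Stand-in for the real helper: assume the default inductor backend;
    # the actual function may return something else on non-CUDA devices.
    return "inductor"


@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Index of the last token of each sequence, clamped to 0 for empty ones.
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)


print(clamp_position(torch.tensor([0, 1, 5])))  # tensor([0, 0, 4])
```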
@@ -21,20 +21,26 @@ from typing import List, Optional, Tuple
 
  import torch
  import torch.distributed as dist
- from vllm.distributed import (
+
+ from sglang.srt.configs.device_config import DeviceConfig
+ from sglang.srt.configs.load_config import LoadConfig
+ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+ from sglang.srt.distributed import (
  get_tp_group,
  init_distributed_environment,
  initialize_model_parallel,
  set_custom_all_reduce,
  )
-
- from sglang.srt.configs.device_config import DeviceConfig
- from sglang.srt.configs.load_config import LoadConfig
- from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+ from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
  from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
  from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
  from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
  from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+ from sglang.srt.layers.dp_attention import (
+ get_attention_tp_group,
+ get_attention_tp_size,
+ initialize_dp_attention,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.sampler import Sampler
  from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -57,8 +63,8 @@ from sglang.srt.utils import (
  init_custom_process_group,
  is_cuda,
  is_hip,
+ monkey_patch_p2p_access_check,
  monkey_patch_vllm_gguf_config,
- monkey_patch_vllm_p2p_access_check,
  set_cpu_offload_max_bytes,
  )
 
@@ -101,8 +107,10 @@ class ModelRunner:
  self.model_config.attention_arch == AttentionArch.MLA
  and not self.server_args.disable_mla
  ):
- logger.info("MLA optimization is turned on. Use triton backend.")
- self.server_args.attention_backend = "triton"
+ # TODO: add MLA optimization on CPU
+ if self.server_args.device != "cpu":
+ logger.info("MLA optimization is turned on. Use triton backend.")
+ self.server_args.attention_backend = "triton"
 
  if self.server_args.enable_double_sparsity:
  logger.info(
@@ -159,6 +167,7 @@ class ModelRunner:
  "enable_nan_detection": server_args.enable_nan_detection,
  "enable_dp_attention": server_args.enable_dp_attention,
  "enable_ep_moe": server_args.enable_ep_moe,
+ "device": server_args.device,
  }
  )
 
@@ -176,9 +185,12 @@ class ModelRunner:
  self.load_model()
 
  # Apply torchao quantization
- apply_torchao_config_to_model(
- self.model, global_server_args_dict["torchao_config"]
- )
+ torchao_applied = getattr(self.model, "torchao_applied", False)
+ # In layered loading, torchao may have been applied
+ if not torchao_applied:
+ apply_torchao_config_to_model(
+ self.model, global_server_args_dict["torchao_config"]
+ )
 
  # Apply torch TP if the model supports it
  supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
@@ -206,7 +218,7 @@ class ModelRunner:
 
  def init_torch_distributed(self):
  logger.info("Init torch distributed begin.")
- # Init torch distributed
+
  torch.get_device_module(self.device).set_device(self.gpu_id)
  if self.device == "cuda":
  backend = "nccl"
@@ -216,9 +228,12 @@ class ModelRunner:
  backend = "gloo"
  elif self.device == "hpu":
  backend = "hccl"
+ elif self.device == "cpu":
+ backend = "gloo"
 
  if not self.server_args.enable_p2p_check:
- monkey_patch_vllm_p2p_access_check(self.gpu_id)
+ monkey_patch_p2p_access_check()
+
  if self.server_args.dist_init_addr:
  dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
  else:
@@ -226,7 +241,7 @@ class ModelRunner:
  set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
 
  if not self.is_draft_worker:
- # Only initilzie the distributed environment on the target model worker.
+ # Only initialize the distributed environment on the target model worker.
  init_distributed_environment(
  backend=backend,
  world_size=self.tp_size,
@@ -235,11 +250,18 @@ class ModelRunner:
  distributed_init_method=dist_init_method,
  )
  initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+ initialize_dp_attention(
+ enable_dp_attention=self.server_args.enable_dp_attention,
+ tp_rank=self.tp_rank,
+ tp_size=self.tp_size,
+ dp_size=self.server_args.dp_size,
+ )
 
  min_per_gpu_memory = get_available_gpu_memory(
  self.device, self.gpu_id, distributed=self.tp_size > 1
  )
  self.tp_group = get_tp_group()
+ self.attention_tp_group = get_attention_tp_group()
 
  # Check memory for tensor parallelism
  if self.tp_size > 1:
@@ -257,7 +279,8 @@ class ModelRunner:
  )
 
  # This can reduce thread conflicts and speed up weight loading.
- torch.set_num_threads(1)
+ if self.device != "cpu":
+ torch.set_num_threads(1)
  if self.device == "cuda":
  if torch.cuda.get_device_capability()[0] < 8:
  logger.info(
@@ -277,12 +300,15 @@ class ModelRunner:
  monkey_patch_vllm_gguf_config()
 
  # Load the model
+ # Remove monkey_patch when linear.py quant remove dependencies with vllm
+ monkey_patch_vllm_parallel_state()
  with self.memory_saver_adapter.region():
  self.model = get_model(
  model_config=self.model_config,
  load_config=self.load_config,
  device_config=DeviceConfig(self.device),
  )
+ monkey_patch_vllm_parallel_state(reverse=True)
 
  if self.server_args.kv_cache_dtype == "fp8_e4m3":
  if self.server_args.quantization_param_path is not None:
@@ -521,7 +547,7 @@ class ModelRunner:
  )
  else:
  cell_size = (
- self.model_config.get_num_kv_heads(self.tp_size)
+ self.model_config.get_num_kv_heads(get_attention_tp_size())
  * self.model_config.head_dim
  * self.model_config.num_hidden_layers
  * 2
@@ -595,7 +621,6 @@ class ModelRunner:
  size=max_num_reqs + 1,
  max_context_len=self.model_config.context_len + 4,
  device=self.device,
- use_records=False,
  enable_memory_saver=self.server_args.enable_memory_saver,
  )
  if (
@@ -615,7 +640,7 @@ class ModelRunner:
  self.token_to_kv_pool = DoubleSparseTokenToKVPool(
  self.max_total_num_tokens,
  dtype=self.kv_cache_dtype,
- head_num=self.model_config.get_num_kv_heads(self.tp_size),
+ head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
  layer_num=self.model_config.num_hidden_layers,
  device=self.device,
@@ -626,7 +651,7 @@ class ModelRunner:
  self.token_to_kv_pool = MHATokenToKVPool(
  self.max_total_num_tokens,
  dtype=self.kv_cache_dtype,
- head_num=self.model_config.get_num_kv_heads(self.tp_size),
+ head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
  layer_num=self.model_config.num_hidden_layers,
  device=self.device,
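The hunks above appear to come from sglang/srt/model_executor/model_runner.py (entry 82): imports move from vllm.distributed to sglang.srt.distributed, CPU is accepted as a device (gloo backend, MLA left off, threads not pinned), DP attention gets its own process-group initialization, and KV-cache sizing switches from tp_size to get_attention_tp_size(). A back-of-envelope sketch of that cell-size arithmetic with illustrative numbers; the assumption that the attention TP size is tp_size // dp_size under DP attention is mine, not stated in this diff.

```python
def kv_cell_size_bytes(
    total_kv_heads: int,
    head_dim: int,
    num_layers: int,
    attention_tp_size: int,
    dtype_bytes: int = 2,  # e.g. fp16/bf16 KV cache
) -> int:
    # Per-token KV-cache footprint on one rank: K and V for every layer,
    # with KV heads sharded across the attention TP group.
    heads_per_rank = max(1, total_kv_heads // attention_tp_size)
    return heads_per_rank * head_dim * num_layers * 2 * dtype_bytes


tp_size, dp_size = 8, 2
attention_tp_size = tp_size // dp_size  # assumed relationship under DP attention

# Illustrative Llama-like shape: 8 KV heads, head_dim 128, 32 layers.
print(kv_cell_size_bytes(8, 128, 32, attention_tp_size))  # 32768 bytes/token
print(kv_cell_size_bytes(8, 128, 32, tp_size))            # 16384 bytes/token
```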
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
  from torch import nn
  from transformers import AutoModelForCausalLM, PretrainedConfig
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
- from vllm.distributed import (
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- )
 
  from sglang.srt.configs.device_config import DeviceConfig
  from sglang.srt.configs.load_config import LoadConfig, LoadFormat
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ )
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_loader.utils import (
  get_model_architecture,
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
  return model.eval()
 
 
+ class LayeredModelLoader(DefaultModelLoader):
+ """Model loader that loads weights layer by layer so that one can quantize a
+ layer before loading another to make the peak memory envelope smaller."""
+
+ def __init__(self, load_config: LoadConfig):
+ # Back to the default load format
+ load_config.load_format = LoadFormat.AUTO
+ super().__init__(load_config)
+
+ def load_model(
+ self,
+ *,
+ model_config: ModelConfig,
+ device_config: DeviceConfig,
+ ) -> nn.Module:
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+ torchao_config = global_server_args_dict.get("torchao_config")
+ target_device = torch.device(device_config.device)
+
+ with set_default_torch_dtype(model_config.dtype):
+ # Create model on meta device
+ with torch.device("meta"):
+ model = _initialize_model(
+ model_config,
+ self.load_config,
+ )
+
+ # Check model's layered load support
+ if not hasattr(model, "load_weights_to_module"):
+ raise ValueError(
+ "LayeredModelLoader requires the model to have a "
+ "`load_weights_to_module` method. "
+ f"{model_config.model_path} does not support it."
+ )
+
+ # Get all weights from disk
+ weights = self._get_all_weights(model_config, model)
+
+ # Helper function to recursively fill the weights of a module
+ def fill_module(module, fqn: List[str], weights):
+ """
+ fqn: list of strings representing the fully qualified name of `module`.
+ """
+ # Layer by layer
+ for name, submod in module.named_children():
+ fill_module(submod, fqn + [name], weights)
+
+ # First materialize on target device
+ module.to_empty(device=target_device, recurse=False)
+ fqn_path = ".".join(fqn)
+ # Fill weights
+ model.load_weights_to_module(
+ fqn_path,
+ weights,
+ )
+ # Quantize weights if applicable
+ if torchao_config and "proj" in fqn_path:
+ # Note: `None` here is needed to indicate no filter, see
+ # `apply_torchao_config_to_model` for details.
+ apply_torchao_config_to_model(module, torchao_config, None)
+
+ # Start calling on root module
+ fill_module(model, [], weights)
+
+ if torchao_config:
+ model.torchao_applied = True
+
+ return model.eval()
+
+
  class DummyModelLoader(BaseModelLoader):
  """Model loader that will set model weights to random values."""
 
@@ -496,7 +568,8 @@ class ShardedStateLoader(BaseModelLoader):
  device_config: DeviceConfig,
  ) -> nn.Module:
  from safetensors.torch import safe_open
- from vllm.distributed import get_tensor_model_parallel_rank
+
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
 
  local_model_path = self._prepare_weights(
  model_config.model_path, model_config.revision
@@ -556,7 +629,8 @@ class ShardedStateLoader(BaseModelLoader):
  max_size: Optional[int] = None,
  ) -> None:
  from safetensors.torch import save_file
- from vllm.distributed import get_tensor_model_parallel_rank
+
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
 
  if pattern is None:
  pattern = ShardedStateLoader.DEFAULT_PATTERN
@@ -1147,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  if load_config.load_format == LoadFormat.GGUF:
  return GGUFModelLoader(load_config)
 
+ if load_config.load_format == LoadFormat.LAYERED:
+ return LayeredModelLoader(load_config)
+
  return DefaultModelLoader(load_config)
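Finally, these hunks appear to be from sglang/srt/model_loader/loader.py (entry 83): a new LayeredModelLoader materializes and fills the model module by module in post-order, optionally quantizing projection submodules as it goes, and get_model_loader now dispatches LoadFormat.LAYERED to it. Below is a minimal standalone sketch of the same "materialize, fill, then quantize" traversal over a meta-device model; the weight-copy loop and the quantize placeholder stand in for the model's load_weights_to_module and torchao, which are not reproduced here, and a PyTorch version where Module.to_empty(recurse=False) is available is assumed.

```python
from typing import Callable, Dict, List

import torch
from torch import nn


def fill_module(
    module: nn.Module,
    fqn: List[str],
    weights: Dict[str, torch.Tensor],
    target_device: torch.device,
    quantize: Callable[[nn.Module], None],
) -> None:
    # Children first (post-order), so leaves are materialized and can be
    # quantized before their parents, keeping the peak memory envelope small.
    for name, child in module.named_children():
        fill_module(child, fqn + [name], weights, target_device, quantize)

    # Materialize only this module's own parameters off the meta device...
    module.to_empty(device=target_device, recurse=False)
    prefix = ".".join(fqn)
    # ...copy in the matching checkpoint tensors...
    for pname, param in module.named_parameters(recurse=False):
        key = f"{prefix}.{pname}" if prefix else pname
        if key in weights:
            param.data.copy_(weights[key])
    # ...then hand the filled module to the (placeholder) quantizer.
    if "proj" in prefix:
        quantize(module)


with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

# Pretend checkpoint: CPU tensors keyed by fully qualified parameter name.
weights = {name: torch.randn(p.shape) for name, p in model.named_parameters()}
fill_module(model, [], weights, torch.device("cpu"), quantize=lambda m: None)
print(next(model.parameters()).device)  # cpu -- no longer on the meta device
```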