sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,9 @@ except ImportError:
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    fp8_max,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 
-if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
+if _is_hip and use_aiter_moe:
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = None
 
-_TORCH_VERSION = torch.__version__.split("+")[0]
-try:
-    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
-except ValueError:
-    _TORCH_VERSION_TUPLE = (0, 0, 0)
-
-# The condition to determine if it is on a platform that supports
-# torch._scaled_mm rowwise feature.
-# The condition is determined once as the operations
-# are time consuming.
-USE_ROWWISE_TORCH_SCALED_MM = (
-    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
-)
+
+def use_rowwise_torch_scaled_mm():
+    _TORCH_VERSION = torch.__version__.split("+")[0]
+    try:
+        _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+    except ValueError:
+        _TORCH_VERSION_TUPLE = (0, 0, 0)
+    if _is_hip:
+        # The condition to determine if it is on a platform that supports
+        # torch._scaled_mm rowwise feature.
+        # The condition is determined once as the operations
+        # are time consuming.
+        return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+    return False
+
+
+USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
 
 
 def cutlass_fp8_supported():
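Side note on the hunk above: the torch version parse and the device-capability probe are now wrapped in `use_rowwise_torch_scaled_mm()`, and the expensive checks only run on ROCm. A minimal standalone sketch of the same version-tuple handling the check relies on (not package code):

```python
import torch


def torch_version_tuple() -> tuple:
    # "2.7.0+rocm6.3" -> (2, 7, 0); unparsable local builds fall back to (0, 0, 0)
    base = torch.__version__.split("+")[0]
    try:
        return tuple(map(int, base.split(".")[:3]))
    except ValueError:
        return (0, 0, 0)


print(torch_version_tuple() >= (2, 7, 0))
```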
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+    elif _is_hip and use_aiter_moe:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values with tensor-wise quantization."""
-    finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-    scale = fp8_max / amax
-    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
+
+    if _is_fp8_fnuz:
+        dtype = fp8_dtype
+        fp_max = fp8_max
+    else:
+        finfo = torch.finfo(dtype)
+        fp_max = finfo.max
+
+    scale = fp_max / amax
+    x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
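A quick illustration of the tensor-wise contract above: `input_to_float8` returns the saturated fp8 tensor together with the reciprocal scale, so dequantization is a single multiply. A hedged round-trip sketch, assuming an environment where sglang's quantization modules import cleanly (shapes are illustrative):

```python
import torch
from sglang.srt.layers.quantization.fp8_utils import input_to_float8

x = torch.randn(16, 32, dtype=torch.float32)
x_fp8, scale_inv = input_to_float8(x)        # scale_inv is 1 / scale
x_rec = x_fp8.to(torch.float32) * scale_inv  # dequantize with the reciprocal scale
print((x - x_rec).abs().max())               # small; bounded by fp8 resolution
```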
@@ -222,6 +235,41 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def block_quant_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """This function converts block-wise quantization to unquantized.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The output is an unquantized tensor with dtype.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+
+    for j in range(n_tiles):
+        for i in range(k_tiles):
+            x_q_block_tile = x_q_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile = x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
+
+    return x_dq_block
+
+
 def channel_quant_to_tensor_quant(
     x_q_channel: torch.Tensor,
     x_s: torch.Tensor,
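For readers of the new `block_quant_dequant` above: each `(block_n, block_k)` tile of `x_q_block` is multiplied by its scalar entry in `x_s`, so the scale grid must match the tile grid. A hedged usage sketch with hand-built inputs (shapes, dtype, and the import path are assumptions inferred from this file's other helpers):

```python
import torch
from sglang.srt.layers.quantization.fp8_utils import block_quant_dequant  # assumed location

n, k, block = 256, 512, 128
x_q = torch.randn(n, k).to(torch.float8_e4m3fn)  # stand-in for a block-quantized weight
x_s = torch.rand(n // block, k // block) + 0.5   # one scale per 128x128 tile
w = block_quant_dequant(x_q, x_s, [block, block], torch.bfloat16)
assert w.shape == (n, k) and w.dtype == torch.bfloat16
```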
@@ -76,7 +76,7 @@ def _per_token_group_quant_int8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -370,7 +370,7 @@ def w8a8_block_int8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
         )
 
-    @classmethod
-    def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             # We prefer to use separate k_scale and v_scale if present
             k_scale = layer.k_scale.to("cpu").tolist()
             v_scale = layer.v_scale.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
         elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             scale_to_duplicate = max(layer.k_scale, layer.v_scale)
             k_scale = scale_to_duplicate.to("cpu").tolist()
             v_scale = scale_to_duplicate.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
 
@@ -14,11 +14,6 @@ if not _is_cuda:
     from vllm._custom_ops import scaled_fp8_quant
 
 
-def is_fp8_fnuz() -> bool:
-    # only device 0 is checked, this assumes MI300 platforms are homogeneous
-    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
-
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
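The duplicated `is_fp8_fnuz` definitions removed above (the `BaseKVCacheMethod` classmethod and the module-level copy) are replaced by a single helper imported from `fp8_kernel`, as the edited call sites show. A hedged sketch of what a caller looks like after the consolidation (the scale values are illustrative, not from the package):

```python
from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype, is_fp8_fnuz

k_scale, v_scale = 0.5, 0.5  # illustrative checkpoint KV-cache scales
# On gfx94* (MI300-class) ROCm devices fp8 is the e4m3fnuz variant, whose
# representable range is half of OCP e4m3fn, hence the doubled scales.
if is_fp8_fnuz():
    k_scale, v_scale = k_scale * 2, v_scale * 2
print(fp8_dtype, k_scale, v_scale)
```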
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
+    per_token_group_quant_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import set_weight_attrs
 
-_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
 
 
 class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         if self.quantization_config.is_checkpoint_fp8_serialized:
             weight_scale = layer.weight_scale.detach()
             # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight, weight_scale=weight_scale
                 )
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                     layer.weight, layer.weight.shape[-1]
                 )
                 weight_scale = weight_scale.t().contiguous()
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight, weight_scale=weight_scale
-                    )
             else:
                 # if cutlass not supported, we fall back to use torch._scaled_mm
                 # which requires per tensor quantization on weight
-                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                 qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
 
         # Update the layer with the new values.
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
     ):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(
 
 
 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         self.enable_tp = enable_tp
         if self.enable_tp:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1
 
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
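With the bare `*` added above, every argument after `embedding_dim` in `VocabParallelEmbedding` and `ParallelLMHead` becomes keyword-only, which is what lets the new `use_attn_tp_group` flag be threaded through `super().__init__` without silently shifting positional callers. A hedged construction sketch (assumes a process where sglang's parallel groups are already initialized; the sizes are illustrative):

```python
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

# Positional calls such as ParallelLMHead(vocab, dim, dtype, ...) now raise TypeError.
lm_head = ParallelLMHead(
    num_embeddings=128_256,
    embedding_dim=4_096,
    use_attn_tp_group=True,  # shard over the attention TP group rather than the full TP group
)
```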
@@ -100,7 +100,7 @@ class LoRAManager:
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)
 
-        # Target lora weight names for lora_a and lora_b modules repectively.
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
@@ -156,18 +156,15 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-        if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
-            # Do in-place updates when CUDA graph is enabled. Note that
-            # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
-            # will also use these preallocated buffers, no matter whether
-            # the batch can use CUDA graph or not.
+        if (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        ):
+            # Do in-place updates when CUDA graph is enabled and the batch forward mode
+            # could use CUDA graph.
             self.cuda_graph_batch_info.bs = bs
-            if forward_batch.forward_mode.is_extend():
-                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
-                    forward_batch.extend_seq_lens
-                )
-            else:
-                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
             torch.cumsum(
                 self.cuda_graph_batch_info.seg_lens[:bs],
                 dim=0,
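Restating the new guard above as a standalone predicate may help readers scanning the diff: the preallocated CUDA-graph buffers are now only updated in place when the batch's forward mode itself is capturable, so `seg_lens` can always be filled with 1. The helper name below is hypothetical, not part of the package:

```python
def should_use_cuda_graph_buffers(manager, forward_batch) -> bool:
    # Hypothetical restatement of the condition used in LoRAManager above
    return (
        hasattr(manager, "max_bs_in_cuda_graph")
        and forward_batch.batch_size <= manager.max_bs_in_cuda_graph
        and forward_batch.forward_mode.is_cuda_graph()
    )
```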
@@ -201,10 +198,10 @@ class LoRAManager:
             max_len = int(torch.max(seg_lens))
             weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-            lora_ranks = torch.empty(
+            lora_ranks = torch.zeros(
                 (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
             )
-            scalings = torch.empty(
+            scalings = torch.zeros(
                 (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
             )
             for i, lora_path in enumerate(forward_batch.lora_paths):
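The switch from `torch.empty` to `torch.zeros` above is presumably defensive: the loop that follows only fills slots for adapters present in the batch, so with zero-initialization the untouched slots read as rank 0 / scaling 0.0 instead of uninitialized memory. A tiny standalone illustration (the rank and scaling values are made up):

```python
import torch

max_loras_per_batch = 4
lora_ranks = torch.zeros((max_loras_per_batch,), dtype=torch.int64)
scalings = torch.zeros((max_loras_per_batch,), dtype=torch.float)
# Only two adapters appear in this batch; slots 2 and 3 stay at rank 0 / scaling 0.0.
lora_ranks[0], scalings[0] = 16, 2.0
lora_ranks[1], scalings[1] = 8, 1.0
print(lora_ranks.tolist(), scalings.tolist())
```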
@@ -50,15 +50,15 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are initalized as empty strings for empty buffer slots
-        # Here we don't initalize to None since None is a valid uid
+        # All uids are initialized as empty strings for empty buffer slots
+        # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
@@ -75,7 +75,7 @@ class LoRAMemoryPool:
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
@@ -77,7 +77,7 @@ def _gate_up_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
@@ -79,7 +79,7 @@ def _qkv_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
@@ -67,7 +67,7 @@ def _sgemm_lora_a_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
@@ -69,7 +69,7 @@ def _sgemm_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/utils.py CHANGED
@@ -79,7 +79,7 @@ def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
     """
-    Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+    Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """
 
     if hasattr(base_model, "get_hidden_dim"):