sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/qoq.py ADDED
@@ -0,0 +1,244 @@
+ from typing import Any, Callable, Dict, List, Optional
+
+ import torch
+ from torch.nn.parameter import Parameter
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.linear import LinearMethodBase
+ from sglang.srt.layers.parameter import (
+     ChannelQuantScaleParameter,
+     GroupQuantScaleParameter,
+     ModelWeightParameter,
+ )
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+ if _is_cuda:
+     from sgl_kernel import qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm
+
+
+ QoQ_SUPPORTED_WEIGHT_BITS = [4]
+ QoQ_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+
+ class QoQConfig(QuantizationConfig):
+     """Config class for QoQ Quantization.
+
+     - Weight: static, per-channel/group, asymmetric
+     - Activation: dynamic, per-token, symmetric
+
+     Reference: https://arxiv.org/abs/2405.04532
+     https://github.com/mit-han-lab/omniserve
+     """
+
+     def __init__(self, weight_bits: int, group_size: int) -> None:
+         self.weight_bits = weight_bits
+         self.group_size = group_size
+
+         # Verify
+         if self.weight_bits not in QoQ_SUPPORTED_WEIGHT_BITS:
+             raise ValueError(
+                 f"QoQ does not support weight_bits = {self.weight_bits}. "
+                 f"Only weight_bits = {QoQ_SUPPORTED_WEIGHT_BITS} "
+                 "are supported."
+             )
+         if self.group_size not in QoQ_SUPPORTED_GROUP_SIZES:
+             raise ValueError(
+                 f"QoQ does not support group_size = {self.group_size}. "
+                 f"Only group_sizes = {QoQ_SUPPORTED_GROUP_SIZES} "
+                 "are supported."
+             )
+
+         # 4 bits packed into 8 bit datatype.
+         self.pack_factor = 8 // self.weight_bits
+
+     def __repr__(self) -> str:
+         return "QoQConfig(weight_bits={}, group_size={})".format(
+             self.weight_bits, self.group_size
+         )
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.float16]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 80
+
+     @classmethod
+     def get_name(self) -> str:
+         return "qoq"
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         """List of filenames to search for in the model directory."""
+         return [
+             "quant_config.json",
+             "quantize_config.json",
+         ]
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
+         weight_bits = cls.get_from_keys(config, ["wbits"])
+         group_size = cls.get_from_keys(config, ["group_size"])
+         return cls(weight_bits, group_size)
+
+     def get_quant_method(
+         self,
+         layer: torch.nn.Module,
+         prefix: str,
+     ) -> Optional["QuantizeMethodBase"]:
+         from sglang.srt.layers.linear import LinearBase
+
+         if isinstance(layer, LinearBase):
+             return QoQLinearMethod(self)
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class QoQLinearMethod(LinearMethodBase):
+     """Linear method for QoQ.
+
+     Args:
+         quant_config: The QoQ quantization config.
+     """
+
+     def __init__(self, quant_config: QoQConfig):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: List[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+
+         weight_loader = extra_weight_attrs.get("weight_loader")
+
+         # Validate output_size_per_partition
+         output_size_per_partition = sum(output_partition_sizes)
+         if output_size_per_partition % 32 != 0:
+             raise ValueError(
+                 f"Weight output_size_per_partition = "
+                 f"{output_size_per_partition} is not divisible by 32."
+             )
+
+         # Validate input_size_per_partition
+         if input_size_per_partition % self.quant_config.pack_factor != 0:
+             raise ValueError(
+                 f"Weight input_size_per_partition = "
+                 f"{input_size_per_partition} is not divisible by "
+                 f"pack_factor = {self.quant_config.pack_factor}."
+             )
+         if (
+             self.quant_config.group_size != -1
+             and input_size_per_partition % self.quant_config.group_size != 0
+         ):
+             raise ValueError(
+                 f"Weight input_size_per_partition = "
+                 f"{input_size_per_partition} is not divisible by "
+                 f"group_size = {self.quant_config.group_size}."
+             )
+
+         qweight = ModelWeightParameter(
+             data=torch.empty(
+                 output_size_per_partition,
+                 input_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int8,
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+         layer.register_parameter("qweight", qweight)
+
+         s1_scales = ChannelQuantScaleParameter(
+             data=torch.empty(output_size_per_partition, dtype=torch.float16),
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+         layer.register_parameter("s1_scales", s1_scales)
+
+         if self.quant_config.group_size == -1:
+             s1_szeros = ChannelQuantScaleParameter(
+                 data=torch.empty(output_size_per_partition, dtype=torch.float16),
+                 output_dim=0,
+                 weight_loader=weight_loader,
+             )
+             layer.register_parameter("s1_szeros", s1_szeros)
+         else:
+             s2_scales = GroupQuantScaleParameter(
+                 data=torch.empty(
+                     (
+                         input_size_per_partition // self.quant_config.group_size,
+                         output_size_per_partition,
+                     ),
+                     dtype=torch.int8,
+                 ),
+                 input_dim=0,
+                 output_dim=1,
+                 weight_loader=weight_loader,
+             )
+             layer.register_parameter("s2_scales", s2_scales)
+
+             s2_zeros = GroupQuantScaleParameter(
+                 data=torch.empty(
+                     (
+                         input_size_per_partition // self.quant_config.group_size,
+                         output_size_per_partition,
+                     ),
+                     dtype=torch.int8,
+                 ),
+                 input_dim=0,
+                 output_dim=1,
+                 weight_loader=weight_loader,
+             )
+             layer.register_parameter("s2_zeros", s2_zeros)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+         layer.s1_scales = Parameter(layer.s1_scales.data, requires_grad=False)
+         if self.quant_config.group_size == -1:
+             layer.s1_szeros = Parameter(layer.s1_szeros.data, requires_grad=False)
+         else:
+             layer.s2_scales = Parameter(layer.s2_scales.data, requires_grad=False)
+             layer.s2_zeros = Parameter(layer.s2_zeros.data, requires_grad=False)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ):
+         assert x.dtype == torch.float16, "QoQ only supports float16 input now"
+         if self.quant_config.group_size == -1:
+             x_q, x_scale, x_sum = per_token_quant_int8(
+                 x, scale_dtype=x.dtype, cal_sum=True
+             )
+             out = qserve_w4a8_per_chn_gemm(
+                 x_q, layer.qweight, layer.s1_scales, x_scale, layer.s1_szeros, x_sum
+             )
+         else:
+             x_q, x_scale = per_token_quant_int8(x, scale_dtype=x.dtype)
+             out = qserve_w4a8_per_group_gemm(
+                 x_q,
+                 layer.qweight,
+                 layer.s2_zeros,
+                 layer.s2_scales,
+                 layer.s1_scales,
+                 x_scale,
+             )
+         if bias is not None:
+             out = out + bias
+         return out
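The hunk above adds a complete W4A8 QoQ (QServe) quantization backend. As a quick orientation, here is a minimal sketch of how the new config class behaves, assuming an sglang build where sgl_kernel provides the qserve GEMMs; only names introduced by this file are used, and the printed values follow directly from the code above.

from sglang.srt.layers.quantization.qoq import QoQConfig

# The kind of dict loaded from quant_config.json / quantize_config.json.
cfg = QoQConfig.from_config({"wbits": 4, "group_size": 128})
print(cfg)               # QoQConfig(weight_bits=4, group_size=128)
print(cfg.pack_factor)   # 2 -> two 4-bit weights packed per int8 element

# For LinearBase layers, get_quant_method returns a QoQLinearMethod, whose
# create_weights registers qweight/s1_scales (plus s2_scales/s2_zeros in the
# grouped case) and whose apply() dispatches to qserve_w4a8_per_chn_gemm or
# qserve_w4a8_per_group_gemm depending on group_size.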
sglang/srt/layers/sampler.py CHANGED
@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(


  def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-     assert len(top_logprobs_nums) == logprobs.shape[0], (
-         len(top_logprobs_nums),
-         logprobs.shape[0],
-     )
      max_k = max(top_logprobs_nums)
      ret = logprobs.topk(max_k, dim=1)
      values = ret.values.tolist()
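The four removed lines drop a shape assertion from get_top_logprobs; what remains is a single batched top-k over the logprobs tensor. A self-contained sketch of that pattern with made-up shapes and values (an illustration of the idiom, not the full sglang function):

import torch

logprobs = torch.log_softmax(torch.randn(3, 8), dim=-1)  # [batch, vocab], toy data
top_logprobs_nums = [2, 5, 3]                            # per-request k

max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)   # one batched topk call for all requests
values = ret.values.tolist()
indices = ret.indices.tolist()

# Each request then keeps only its own first k entries (assumed downstream step).
per_request = [
    (values[i][:k], indices[i][:k]) for i, k in enumerate(top_logprobs_nums)
]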
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
  from sglang.srt.layers.parameter import BasevLLMParameter
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
          self,
          num_embeddings: int,
          embedding_dim: int,
+         *,
          params_dtype: Optional[torch.dtype] = None,
          org_num_embeddings: Optional[int] = None,
          padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
          enable_tp: bool = True,
+         use_attn_tp_group: bool = False,
          use_presharded_weights: bool = False,
      ):
          super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):

          self.enable_tp = enable_tp
          if self.enable_tp:
-             tp_rank = get_tensor_model_parallel_rank()
-             self.tp_size = get_tensor_model_parallel_world_size()
+             if use_attn_tp_group:
+                 tp_rank = get_attention_tp_rank()
+                 self.tp_size = get_attention_tp_size()
+             else:
+                 tp_rank = get_tensor_model_parallel_rank()
+                 self.tp_size = get_tensor_model_parallel_world_size()
          else:
+             assert use_attn_tp_group is False
              tp_rank = 0
              self.tp_size = 1

@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
          self,
          num_embeddings: int,
          embedding_dim: int,
+         *,
          bias: bool = False,
          params_dtype: Optional[torch.dtype] = None,
          org_num_embeddings: Optional[int] = None,
          padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
+         use_attn_tp_group: bool = False,
          use_presharded_weights: bool = False,
      ):
          super().__init__(
              num_embeddings,
              embedding_dim,
-             params_dtype,
-             org_num_embeddings,
-             padding_size,
-             quant_config,
-             prefix,
+             params_dtype=params_dtype,
+             org_num_embeddings=org_num_embeddings,
+             padding_size=padding_size,
+             quant_config=quant_config,
+             prefix=prefix,
+             use_attn_tp_group=use_attn_tp_group,
              use_presharded_weights=use_presharded_weights,
          )
          self.quant_config = quant_config
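With the `*,` added above, every constructor argument after embedding_dim is keyword-only for VocabParallelEmbedding and ParallelLMHead, and ParallelLMHead now forwards them by name, including the new use_attn_tp_group flag. A hedged sketch of an updated call site, assuming it runs inside an already-initialized sglang tensor-parallel worker; the sizes and prefix are placeholders, not taken from the diff:

from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

vocab_size, hidden_size = 32000, 4096  # placeholder model dimensions

# Positional style such as ParallelLMHead(vocab_size, hidden_size, None, None, ...)
# now raises TypeError; everything past embedding_dim must be a keyword.
lm_head = ParallelLMHead(
    vocab_size,
    hidden_size,
    params_dtype=None,
    quant_config=None,
    prefix="lm_head",
    use_attn_tp_group=False,   # new flag: shard with the attention-TP group instead of the full TP group
    use_presharded_weights=False,
)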
sglang/srt/lora/lora_manager.py CHANGED
@@ -100,7 +100,7 @@ class LoRAManager:
              self.configs[name] = LoRAConfig(path)
              self.hf_target_names.update(self.configs[name].target_modules)

-         # Target lora weight names for lora_a and lora_b modules repectively.
+         # Target lora weight names for lora_a and lora_b modules respectively.
          # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
          self.lora_weight_names: Set[Tuple[str]] = set(
              [get_stacked_name(module) for module in self.hf_target_names]
@@ -170,9 +170,7 @@ class LoRAManager:
                  dim=0,
                  out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
              )
-             self.cuda_graph_batch_info.max_len = int(
-                 torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
-             )
+             self.cuda_graph_batch_info.max_len = 1

              for i, lora_path in enumerate(forward_batch.lora_paths):
                  self.cuda_graph_batch_info.weight_indices[i] = (
sglang/srt/lora/mem_pool.py CHANGED
@@ -50,15 +50,15 @@ class LoRAMemoryPool:
          self.uid_to_buffer_id: Dict[Optional[str], int] = {}

          # Buffer idx -> lora uid in memory pool
-         # All uids are initalized as empty strings for empty buffer slots
-         # Here we don't initalize to None since None is a valid uid
+         # All uids are initialized as empty strings for empty buffer slots
+         # Here we don't initialize to None since None is a valid uid
          self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

      def get_lora_A_shape(
          self, module_name: str, base_model: torch.nn.Module
      ) -> Tuple[int]:
          """
-         Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
          """
          input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
          c = get_stacked_multiply(module_name)
@@ -75,7 +75,7 @@ class LoRAMemoryPool:
          self, module_name: str, base_model: torch.nn.Module
      ) -> Tuple[int]:
          """
-         Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
          """
          _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
          c = get_stacked_multiply(module_name)
sglang/srt/lora/triton_ops/gate_up_lora_b.py CHANGED
@@ -77,7 +77,7 @@ def _gate_up_lora_b_kernel(
          k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
      )

-     # Iteate to compute the block in output matrix
+     # Iterate to compute the block in output matrix
      partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -79,7 +79,7 @@ def _qkv_lora_b_kernel(
          k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
      )

-     # Iteate to compute the block in output matrix
+     # Iterate to compute the block in output matrix
      partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_a.py CHANGED
@@ -67,7 +67,7 @@ def _sgemm_lora_a_kernel(
          k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
      )

-     # Iteate to compute the block in output matrix
+     # Iterate to compute the block in output matrix
      partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_b.py CHANGED
@@ -69,7 +69,7 @@ def _sgemm_lora_b_kernel(
          k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
      )

-     # Iteate to compute the block in output matrix
+     # Iterate to compute the block in output matrix
      partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
sglang/srt/lora/utils.py CHANGED
@@ -79,7 +79,7 @@ def get_hidden_dim(
      module_name: str, config: AutoConfig, base_model: torch.nn.Module
  ) -> Tuple[int]:
      """
-     Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+     Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
      """

      if hasattr(base_model, "get_hidden_dim"):
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -17,13 +17,13 @@ import logging
  import multiprocessing as mp
  import signal
  import threading
+ import time
  from enum import Enum, auto

  import psutil
  import setproctitle
  import zmq

- from sglang.srt.disaggregation.utils import DisaggregationMode
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.managers.io_struct import (
      TokenizedEmbeddingReqInput,
@@ -158,7 +158,7 @@ class DataParallelController:
          # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
          # function in scheduler.py will kill the scheduler.
          while True:
-             pass
+             time.sleep(30 * 24 * 3600)

      def launch_dp_attention_schedulers(self, server_args, port_args):
          self.launch_tensor_parallel_group(server_args, port_args, 0, None)
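The `pass` → `time.sleep(...)` change above turns a busy-wait keep-alive loop into an idle one: the thread only has to stay alive, and spinning on `pass` pins a CPU core while a long sleep costs essentially nothing. A standalone toy illustration of the same idiom (not sglang code):

import threading
import time

def keep_alive():
    # The loop must never return, but it should not burn CPU while idling;
    # a very long sleep (~30 days per wakeup) keeps the thread alive at
    # negligible cost, matching the fix above.
    while True:
        time.sleep(30 * 24 * 3600)

t = threading.Thread(target=keep_alive, daemon=True, name="keep-alive")
t.start()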
@@ -210,7 +210,7 @@ class DataParallelController:
              )
              # compute zmq ports for this dp rank
              rank_port_args = PortArgs.init_new(server_args, dp_rank)
-             # Data parallelism resues the tensor parallelism group,
+             # Data parallelism reuses the tensor parallelism group,
              # so all dp ranks should use the same nccl port.
              rank_port_args.nccl_port = port_args.nccl_port