sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/unquant.py

@@ -1,7 +1,7 @@
  from __future__ import annotations

- import importlib
- from typing import TYPE_CHECKING, Callable, List, Optional
+ import importlib.util
+ from typing import TYPE_CHECKING, List, Optional

  import torch
  import torch.nn.functional as F
@@ -24,7 +24,7 @@ from sglang.srt.utils import (
  )

  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
      from sglang.srt.layers.moe.topk import TopKOutput

  has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -116,9 +116,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
      ) -> torch.Tensor:

          if use_intel_amx_backend(layer):
-             return torch.ops.sgl_kernel.weight_packed_linear(
+             x_shapes = x.shape
+             if len(x_shapes) == 3:
+                 x = x.view(-1, x.shape[-1])
+             output = torch.ops.sgl_kernel.weight_packed_linear(
                  x, layer.weight, bias, True  # is_vnni
              )
+             if len(x_shapes) == 3:
+                 output = output.view(x_shapes[0], x_shapes[1], -1)
+             return output

          return F.linear(x, layer.weight, bias)

@@ -221,31 +227,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          layer: torch.nn.Module,
          x: torch.Tensor,
          topk_output: TopKOutput,
-         *,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-         activation_alpha: Optional[float] = None,
-         swiglu_limit: Optional[float] = None,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:
-         kwargs = {}
-         if activation_alpha is not None:
-             kwargs["activation_alpha"] = activation_alpha
-         if swiglu_limit is not None:
-             kwargs["swiglu_limit"] = swiglu_limit

          return self.forward(
              x=x,
              layer=layer,
              topk_output=topk_output,
-             activation=activation,
-             apply_router_weight_on_input=apply_router_weight_on_input,
-             inplace=inplace,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
-             **kwargs,
+             moe_runner_config=moe_runner_config,
          )

      def forward_cuda(
@@ -253,18 +242,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          layer: torch.nn.Module,
          x: torch.Tensor,
          topk_output: TopKOutput,
-         *,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-         activation_alpha: Optional[float] = None,
-         swiglu_limit: Optional[float] = None,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:

          if self.use_triton_kernels:
              if self.with_bias:
+                 assert self.triton_kernel_moe_with_bias_forward is not None
                  return self.triton_kernel_moe_with_bias_forward(
                      hidden_states=x,
                      w1=layer.w13_weight,
@@ -272,24 +255,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                      b1=layer.w13_weight_bias,
                      b2=layer.w2_weight_bias,
                      topk_output=topk_output,
-                     activation=activation,
-                     activation_alpha=activation_alpha,
-                     swiglu_limit=swiglu_limit,
+                     moe_runner_config=moe_runner_config,
                      w1_pcg=None,
                      w2_pcg=None,
                  )
              else:
+                 assert self.triton_kernel_moe_forward is not None
                  return self.triton_kernel_moe_forward(
                      hidden_states=x,
                      w1=layer.w13_weight,
                      w2=layer.w2_weight,
                      topk_output=topk_output,
+                     moe_runner_config=moe_runner_config,
                  )
          else:
              if _use_aiter:
-                 assert not no_combine, "unsupported"
+                 assert not moe_runner_config.no_combine, "unsupported"
                  topk_weights, topk_ids, _ = topk_output
-                 if apply_router_weight_on_input:
+                 if moe_runner_config.apply_router_weight_on_input:
                      assert (
                          topk_weights.dim() == 2
                      ), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -309,7 +292,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                      topk_ids,
                      activation=(
                          ActivationType.Silu
-                         if activation == "silu"
+                         if moe_runner_config.activation == "silu"
                          else ActivationType.Gelu
                      ),
                  )
@@ -325,13 +308,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                      b1=getattr(layer, "w13_weight_bias", None),
                      b2=getattr(layer, "w2_weight_bias", None),
                      topk_output=topk_output,
-                     inplace=inplace and not no_combine,
-                     activation=activation,
-                     apply_router_weight_on_input=apply_router_weight_on_input,
-                     no_combine=no_combine,
-                     routed_scaling_factor=routed_scaling_factor,
-                     activation_alpha=activation_alpha,
-                     swiglu_limit=swiglu_limit,
+                     moe_runner_config=moe_runner_config,
                  )

      def forward_cpu(
@@ -339,21 +316,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          layer: torch.nn.Module,
          x: torch.Tensor,
          topk_output: TopKOutput,
-         *,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:
-         assert activation == "silu", f"activation = {activation} is not supported."
-
-         if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
+         assert (
+             moe_runner_config.activation == "silu"
+         ), f"activation = {moe_runner_config.activation} is not supported."
+
+         if (
+             use_intel_amx_backend(layer)
+             and not moe_runner_config.apply_router_weight_on_input
+         ):
              from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

              topk_weights, topk_ids, _ = topk_output
              x, topk_weights = apply_topk_weights_cpu(
-                 apply_router_weight_on_input, topk_weights, x
+                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
              )
              return torch.ops.sgl_kernel.fused_experts_cpu(
                  x,
@@ -378,11 +355,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
              layer,
              x,
              topk_output,
-             activation=activation,
-             apply_router_weight_on_input=apply_router_weight_on_input,
-             inplace=inplace,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
+             moe_runner_config,
          )

      def forward_npu(
@@ -390,12 +363,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          layer: torch.nn.Module,
          x: torch.Tensor,
          topk_output: TopKOutput,
-         *,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

@@ -403,11 +371,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
              layer,
              x,
              topk_output,
-             activation=activation,
-             apply_router_weight_on_input=apply_router_weight_on_input,
-             inplace=inplace,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
+             moe_runner_config,
          )

      def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
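Across this file and the quantization backends below, the per-call MoE keyword arguments (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor, activation_alpha, swiglu_limit) are folded into a single moe_runner_config argument. The sketch below is only an illustration of that shape, inferred from the attributes the new code reads; the real MoeRunnerConfig is defined in sglang/srt/layers/moe/moe_runner/base.py and is not shown in this diff, so field names and defaults here are assumptions.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MoeRunnerConfigSketch:
        # Fields mirror the attributes accessed in the hunks above
        # (moe_runner_config.activation, .no_combine, ...); hypothetical stand-in,
        # not the actual sglang class.
        activation: str = "silu"
        apply_router_weight_on_input: bool = False
        inplace: bool = True
        no_combine: bool = False
        routed_scaling_factor: Optional[float] = None

    # Old call sites passed these options as separate keyword arguments;
    # new call sites pass one bundled config object:
    config = MoeRunnerConfigSketch(activation="silu", routed_scaling_factor=2.5)
    # method.apply(layer, x, topk_output, moe_runner_config=config)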
sglang/srt/layers/quantization/utils.py

@@ -146,6 +146,10 @@ def requantize_with_max_scale(
      return max_w_scale, weight


+ def update_tensor_inplace(old: torch.Tensor, new: torch.Tensor) -> None:
+     old.copy_(new)
+
+
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
  # Newly generated tensors need to replace existing tensors that are
  # already registered as parameters by vLLM (and won't be freed)
@@ -172,6 +176,27 @@ replace_parameter(
      mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))


+ def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
+     assert a.shape == b.shape
+     assert a.dtype == b.dtype == torch.float8_e4m3fn
+
+     a_u8 = a.view(torch.uint8)
+     b_u8 = b.view(torch.uint8)
+     diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
+
+     numel = a.numel()
+
+     count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
+     count_tiny_diff = (diff_u8 >= 1).sum().item()
+     count_large_diff = (diff_u8 >= 2).sum().item()
+
+     assert (
+         (count_diff_sign == 0)
+         and (count_tiny_diff / numel < 0.005)
+         and (count_large_diff == 0)
+     ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
+
+
  # Match dynamic rules with module name (prefix) and override quantize
  # config if module (prefix) matches a rule
  def override_config(config: QuantizationConfig, prefix: str):
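The new assert_fp8_all_close helper compares two float8_e4m3fn tensors by their raw uint8 bit patterns rather than by value, allowing a small fraction (under 0.5%) of one-code differences and no larger jumps. A hedged usage sketch, assuming a PyTorch build with float8_e4m3fn support and that the import path matches the file these hunks come from:

    import torch
    from sglang.srt.layers.quantization.utils import assert_fp8_all_close

    w = torch.randn(128, 256)
    a = w.to(torch.float8_e4m3fn)          # reference fp8 copy
    b = w.clone().to(torch.float8_e4m3fn)  # a second path that should yield the same bits
    assert_fp8_all_close(a, b)             # passes: identical bit patterns, all counters zero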
sglang/srt/layers/quantization/w4afp8.py

@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
  from sglang.srt.utils import set_weight_attrs

  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
+     from sglang.srt.layers.moe import MoeRunnerConfig
+     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+     from sglang.srt.layers.moe.topk import StandardTopKOutput

  ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
          self,
          layer: EPMoE,
          x: torch.Tensor,
-         topk_output: TopKOutput,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-         **kwargs,
+         topk_output: StandardTopKOutput,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:

          # TODO(ch-wan): move it out of this class
@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
              layer.w13_input_scale,
              layer.w2_input_scale,
          )
-         if routed_scaling_factor is not None:
-             output *= routed_scaling_factor
+         if moe_runner_config.routed_scaling_factor is not None:
+             output *= moe_runner_config.routed_scaling_factor
          return output
sglang/srt/layers/quantization/w8a8_fp8.py

@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
  from sglang.srt.utils import set_weight_attrs

  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.topk import TopKOutput
+     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+     from sglang.srt.layers.moe.topk import StandardTopKOutput

  _is_fp8_fnuz = is_fp8_fnuz()

@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         topk_output: TopKOutput,
-         *,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
+         topk_output: StandardTopKOutput,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -284,15 +280,11 @@
              layer.w13_weight,
              layer.w2_weight,
              topk_output=topk_output,
-             inplace=inplace,
-             apply_router_weight_on_input=apply_router_weight_on_input,
-             activation=activation,
+             moe_runner_config=moe_runner_config,
              use_fp8_w8a8=True,
              per_channel_quant=True,
              w1_scale=(layer.w13_weight_scale),
              w2_scale=(layer.w2_weight_scale),
              a1_scale=layer.w13_input_scale,
              a2_scale=layer.w2_input_scale,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
          )
sglang/srt/layers/quantization/w8a8_int8.py

@@ -49,6 +49,7 @@ from sglang.srt.utils import (
  )

  if TYPE_CHECKING:
+     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
      from sglang.srt.layers.moe.topk import TopKOutput

  _is_cuda = is_cuda()
@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
          layer: torch.nn.Module,
          x: torch.Tensor,
          topk_output: TopKOutput,
-         *,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -501,7 +497,7 @@

              topk_weights, topk_ids, _ = topk_output
              x, topk_weights = apply_topk_weights_cpu(
-                 apply_router_weight_on_input, topk_weights, x
+                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
              )
              return torch.ops.sgl_kernel.fused_experts_cpu(
                  x,
@@ -525,17 +521,13 @@
              layer.w13_weight,
              layer.w2_weight,
              topk_output=topk_output,
-             inplace=inplace,
-             activation=activation,
-             apply_router_weight_on_input=apply_router_weight_on_input,
+             moe_runner_config=moe_runner_config,
              use_int8_w8a8=True,
              per_channel_quant=True,
              w1_scale=(layer.w13_weight_scale),
              w2_scale=(layer.w2_weight_scale),
              a1_scale=layer.w13_input_scale,
              a2_scale=layer.w2_input_scale,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
          )


@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
          layer,
          x,
          topk_output: TopKOutput,
-         **kwargs,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:

          topk_weights, topk_ids, _ = topk_output
sglang/srt/layers/radix_attention.py

@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
          v_head_dim: int = -1,
          sliding_window_size: int = -1,
          is_cross_attention: bool = False,
+         pos_encoding_mode: str = "NONE",
+         logit_capping_method: str = "tanh",
          quant_config: Optional[QuantizationConfig] = None,
          attn_type: AttentionType = AttentionType.DECODER,
          use_irope: bool = False,
@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
              self.quant_method.create_weights(self)
          self.attn_type = attn_type

+         self.pos_encoding_mode = pos_encoding_mode
+         self.logit_capping_method = logit_capping_method
+         self.xai_temperature_len = -1
+
      def forward(
          self,
          q,
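RadixAttention now records a pos_encoding_mode and a logit_capping_method (defaulting to "tanh") for attention backends to consume. For orientation only, tanh soft-capping of attention logits is conventionally computed as in the sketch below; this illustrates the general technique and is not the kernel code any particular sglang backend uses, and the cap value is an arbitrary example.

    import torch

    def soft_cap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
        # Squash raw attention logits into (-cap, cap) while staying near-linear around zero.
        return cap * torch.tanh(logits / cap)

    print(soft_cap(torch.tensor([0.5, 40.0, -300.0])))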
sglang/srt/layers/rotary_embedding.py

@@ -1029,6 +1029,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                  f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
              )

+     @torch.compile(dynamic=True)
      def forward(
          self,
          positions: torch.Tensor,
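The only change to MRotaryEmbedding is wrapping its forward in torch.compile with dynamic shapes enabled. A generic, standalone illustration of that decorator, unrelated to this class:

    import torch

    @torch.compile(dynamic=True)
    def double(x: torch.Tensor) -> torch.Tensor:
        # dynamic=True requests dynamic-shape compilation, so inputs of varying
        # length do not each force a recompile.
        return x * 2.0

    print(double(torch.ones(3)))
    print(double(torch.ones(7)))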
sglang/srt/lora/lora_manager.py

@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
      LoRABatchInfo,
      LoRAType,
      get_layer_id,
-     get_normalized_lora_weight_names,
-     get_weight_name,
+     get_normalized_target_modules,
+     get_target_module_name,
  )
  from sglang.srt.managers.io_struct import LoRAUpdateResult
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -55,7 +55,7 @@ class LoRAManager:
          tp_rank: int = 0,
          max_lora_rank: Optional[int] = None,
          target_modules: Optional[Iterable[str]] = None,
-         lora_paths: Optional[Dict[str, LoRARef]] = None,
+         lora_paths: Optional[List[LoRARef]] = None,
      ):
          self.base_model: torch.nn.Module = base_model
          self.base_hf_config: AutoConfig = base_hf_config
@@ -350,19 +350,27 @@ class LoRAManager:
          """
          for layer_id, layer_modules in enumerate(self.lora_modules):
              for module_name, module in layer_modules.items():
-                 weight_name = get_weight_name(
-                     module_name, self.memory_pool.lora_weight_names
+                 target_module = get_target_module_name(
+                     module_name, self.memory_pool.target_modules
                  )
                  module.set_lora_info(
-                     self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                     self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                     self.memory_pool.get_tensor(
+                         target_module=target_module,
+                         layer_id=layer_id,
+                         lora_type=LoRAType.LORA_A,
+                     ),
+                     self.memory_pool.get_tensor(
+                         target_module=target_module,
+                         layer_id=layer_id,
+                         lora_type=LoRAType.LORA_B,
+                     ),
                  )

      def init_state(
          self,
          max_lora_rank: Optional[int] = None,
          target_modules: Optional[Iterable[str]] = None,
-         lora_paths: Optional[Dict[str, LoRARef]] = None,
+         lora_paths: Optional[List[LoRARef]] = None,
      ):
          """
          Initialize the internal (mutable) state of the LoRAManager.
@@ -380,12 +388,11 @@ class LoRAManager:
              max_lora_rank=max_lora_rank,
              target_modules=target_modules,
          )
-         self.init_lora_weight_names()
          self.init_lora_modules()
          self.init_memory_pool()
          self.update_lora_info()

-     def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+     def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
          # Configs of all active LoRA adapters, indexed by LoRA ID.
          self.configs: Dict[str, LoRAConfig] = {}

@@ -399,7 +406,7 @@ class LoRAManager:
          self.num_pinned_loras: int = 0

          if lora_paths:
-             for lora_ref in lora_paths.values():
+             for lora_ref in lora_paths:
                  result = self.load_lora_adapter(lora_ref)
                  if not result.success:
                      raise RuntimeError(
@@ -426,6 +433,7 @@ class LoRAManager:
                      "enable all support modules types. "
                  )
              self.target_modules.update(config.target_modules)
+         self.target_modules = get_normalized_target_modules(self.target_modules)

          if max_lora_rank is not None:
              self.max_lora_rank = max_lora_rank
@@ -435,15 +443,6 @@ class LoRAManager:
                  default=0,
              )

-     def init_lora_weight_names(self):
-         """
-         Add new LoRA weight names if needed based on the current `self.configs`.
-         """
-
-         self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-             self.target_modules
-         )
-
      def load_lora_weights(self, lora_ref: LoRARef):
          """
          Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
@@ -467,7 +466,7 @@ class LoRAManager:
              tp_size=self.tp_size,
              tp_rank=self.tp_rank,
              max_lora_rank=self.max_lora_rank,
-             lora_weight_names=self.lora_weight_names,
+             target_modules=self.target_modules,
              base_model=self.base_model,
          )

@@ -494,7 +493,7 @@ class LoRAManager:
                  continue

              # The module should be converted if it is included in target_names
-             if module_name.split(".")[-1] in self.lora_weight_names:
+             if module_name.split(".")[-1] in self.target_modules:
                  layer_id = get_layer_id(module_name)
                  self.lora_modules[layer_id][module_name] = self.set_lora_module(
                      module_name, module
sglang/srt/lora/lora_registry.py

@@ -59,9 +59,9 @@ class LoRARegistry:
      update / eventual consistency model between the tokenizer manager process and the scheduler processes.
      """

-     def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+     def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
          assert lora_paths is None or all(
-             isinstance(lora, LoRARef) for lora in lora_paths.values()
+             isinstance(lora, LoRARef) for lora in lora_paths
          ), (
              "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
              "Please file an issue if you see this error."
@@ -78,7 +78,7 @@ class LoRARegistry:

          # Initialize the registry with provided LoRA paths, if present.
          if lora_paths:
-             for lora_ref in lora_paths.values():
+             for lora_ref in lora_paths:
                  self._register_adapter(lora_ref)

      async def register(self, lora_ref: LoRARef):
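Throughout the LoRA stack, lora_paths changes from Dict[str, LoRARef] to List[LoRARef], and consumers iterate the list directly instead of calling .values(). A hedged sketch of the call-site change; the LoRARef constructor fields and import path are assumed from the surrounding code, not shown in this diff:

    from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry  # import path assumed

    ref = LoRARef(lora_name="my_adapter", lora_path="/models/my_adapter")  # field names assumed

    # 0.5.0rc2 style: dict keyed by adapter name
    # registry = LoRARegistry(lora_paths={"my_adapter": ref})

    # 0.5.1.post1 style: flat list of LoRARef objects
    registry = LoRARegistry(lora_paths=[ref])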