sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py CHANGED
@@ -1,7 +1,9 @@
+ from __future__ import annotations
+
  import importlib
  import sys
  from types import MappingProxyType
- from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

  import torch
  from torch.nn.parameter import Parameter
@@ -11,21 +13,20 @@ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
  )
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
- from sglang.srt.layers.linear import (
-     LinearMethodBase,
-     RowParallelLinear,
-     UnquantizedLinearMethod,
- )
  from sglang.srt.layers.parameter import (
      ChannelQuantScaleParameter,
      ModelWeightParameter,
      PerTensorScaleParameter,
  )
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
+ from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
  from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.utils import (
      apply_module_patch,
      cpu_has_amx_support,
@@ -36,6 +37,9 @@ from sglang.srt.utils import (
      use_intel_amx_backend,
  )

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
@@ -178,17 +182,18 @@ class W8A8Int8Config(QuantizationConfig):
      - Activation: dynamic, per-token, symmetric
      """

-     def __init__(self, quant_config: Dict[str, Any]):
+     def __init__(self, quant_config: Dict[str, Any] = {}):
          super().__init__()
          self.quant_description = quant_config
          self.is_dynamic = quant_config.get("is_dynamic", False)
-         if _is_npu:
-             if (
-                 "packed_modules_mapping" in quant_config
-                 and quant_config["packed_modules_mapping"] is not None
-             ):
-                 self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+         ignore = cast(List[str], quant_config.get("ignore", []))
+         self.ignore = ignore if ignore is not None else []
+         packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+         self.packed_modules_mapping = (
+             packed_modules_mapping if packed_modules_mapping is not None else {}
+         )

+         if _is_npu:
              # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
              for name in self.quant_description.keys():
                  if "norm.bias" in name:
@@ -229,14 +234,14 @@ class W8A8Int8Config(QuantizationConfig):
          return []

      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
          return cls(config)

      def get_quant_method(
          self,
          layer: torch.nn.Module,
          prefix: str,
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
          from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

@@ -262,12 +267,16 @@ class W8A8Int8Config(QuantizationConfig):
              elif isinstance(layer, FusedMoE):
                  return NPU_W8A8MoEMethod(self)
              return None
-         else:
-             if isinstance(layer, LinearBase):
-                 return W8A8Int8LinearMethod(self)
-             elif isinstance(layer, FusedMoE):
-                 return W8A8Int8MoEMethod(self)
-             return None
+
+         if should_ignore_layer(
+             prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+         ):
+             return UnquantizedLinearMethod()
+         if isinstance(layer, LinearBase):
+             return W8A8Int8LinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return W8A8Int8MoEMethod(self)
+         return None

      def is_layer_skipped(
          self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
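
This reworked dispatch is the user-visible part of the change: prefixes listed under the config's new "ignore" key now short-circuit to the unquantized path before any layer-type checks. A minimal sketch of how that plays out, assuming a non-NPU build; the config dict and the names in "ignore" are hypothetical, chosen only for illustration:

import torch

from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

# Hypothetical checkpoint config; the "ignore" entries are illustrative.
cfg = W8A8Int8Config(
    {
        "is_dynamic": True,
        "ignore": ["lm_head"],
        "packed_modules_mapping": {"qkv_proj": ["q_proj", "k_proj", "v_proj"]},
    }
)

# On the non-NPU path, should_ignore_layer() is consulted before any
# isinstance() checks, so an ignored prefix gets UnquantizedLinearMethod
# rather than W8A8Int8LinearMethod.
method = cfg.get_quant_method(torch.nn.Linear(16, 16), prefix="lm_head")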
@@ -374,7 +383,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
          )


- class W8A8Int8MoEMethod:
+ class W8A8Int8MoEMethod(FusedMoEMethodBase):
      """MoE method for INT8.
      Supports loading INT8 checkpoints with static weight scale and
      dynamic/static activation scale.
@@ -385,25 +394,7 @@
      quant_config: The quantization config.
      """

-     def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-         if not hasattr(cls, "_initialized"):
-             original_init = cls.__init__
-             new_cls = type(
-                 cls.__name__,
-                 (FusedMoEMethodBase,),
-                 {
-                     "__init__": original_init,
-                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                 },
-             )
-             obj = super(new_cls, new_cls).__new__(new_cls)
-             obj.__init__(*args, **kwargs)
-             return obj
-         return super().__new__(cls)
-
-     def __init__(self, quant_config):
+     def __init__(self, quant_config: W8A8Int8Config):
          self.quant_config = quant_config

      def create_weights(
@@ -481,15 +472,8 @@
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
+         topk_output: TopKOutput,
+         *,
          activation: str = "silu",
          apply_router_weight_on_input: bool = False,
          inplace: bool = True,
@@ -497,24 +481,14 @@
          routed_scaling_factor: Optional[float] = None,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-         from sglang.srt.layers.moe.topk import select_experts
-
-         # Expert selection
-         topk_weights, topk_ids = select_experts(
-             hidden_states=x,
-             router_logits=router_logits,
-             use_grouped_topk=use_grouped_topk,
-             top_k=top_k,
-             renormalize=renormalize,
-             topk_group=topk_group,
-             num_expert_group=num_expert_group,
-             num_fused_shared_experts=num_fused_shared_experts,
-             custom_routing_function=custom_routing_function,
-             correction_bias=correction_bias,
-             routed_scaling_factor=routed_scaling_factor,
-         )

          if use_intel_amx_backend(layer):
+             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+             topk_weights, topk_ids, _ = topk_output
+             x, topk_weights = apply_topk_weights_cpu(
+                 apply_router_weight_on_input, topk_weights, x
+             )
              return torch.ops.sgl_kernel.fused_experts_cpu(
                  x,
                  layer.w13_weight,
@@ -536,8 +510,7 @@
              x,
              layer.w13_weight,
              layer.w2_weight,
-             topk_weights=topk_weights,
-             topk_ids=topk_ids,
+             topk_output=topk_output,
              inplace=inplace,
              activation=activation,
              apply_router_weight_on_input=apply_router_weight_on_input,
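
The recurring pattern in these apply() hunks: expert selection is hoisted out of each quantization backend, which now receives a single precomputed TopKOutput. Judging from the AMX branch above, it unpacks as a three-element result whose first two fields are topk_weights and topk_ids. A rough before/after sketch of a call site; the topk producer shown is hypothetical, not sglang's exact API:

# Before: every backend re-ran expert selection from raw router logits.
out = moe_method.apply(
    layer, x, router_logits,
    top_k=8, renormalize=True, use_grouped_topk=False,
)

# After: the caller selects experts once and passes the result through,
# so routing logic lives in one place and backends stay uniform.
topk_output = select_topk(x, router_logits)  # hypothetical producer
topk_weights, topk_ids, _ = topk_output      # unpacking shown in the diff
out = moe_method.apply(layer, x, topk_output, activation="silu", inplace=True)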
@@ -761,6 +734,8 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
+         from sglang.srt.layers.linear import RowParallelLinear
+
          if isinstance(layer, RowParallelLinear):
              tp_rank = get_tensor_model_parallel_rank()
              return self.quant_method.apply(layer, x, bias, tp_rank)
@@ -885,13 +860,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
+         from sglang.srt.layers.linear import RowParallelLinear
+
          if isinstance(layer, RowParallelLinear):
              tp_rank = get_tensor_model_parallel_rank()
              return self.quant_method.apply(layer, x, bias, tp_rank)
          return self.quant_method.apply(layer, x, bias)


- class NPU_W8A8MoEMethod:
+ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
      """MoE method for NPU quantization.

      This class search for specific quantization
@@ -910,7 +887,7 @@ class NPU_W8A8MoEMethod:
          layer: torch.nn.Module,
          num_experts: int,
          hidden_size: int,
-         intermediate_size: List[int],
+         intermediate_size: int,
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ) -> None:
@@ -987,52 +964,11 @@
          self,
          layer,
          x,
-         router_logits,
-         top_k,
-         renormalize,
-         use_grouped_topk,
-         topk_group,
-         num_expert_group,
-         num_fused_shared_experts,
-         custom_routing_function,
-         correction_bias,
-         activation,
-         apply_router_weight_on_input,
-         routed_scaling_factor,
+         topk_output: TopKOutput,
          **kwargs,
      ) -> torch.Tensor:
-         from sglang.srt.layers.moe.topk import select_experts
-
-         global_num_experts = router_logits.shape[-1]
-         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-         if global_num_experts == 256:
-             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                 router_logits,
-                 k=top_k,
-                 bias=correction_bias,
-                 k_group=topk_group,
-                 group_count=num_expert_group,
-                 group_select_mode=1,
-                 renorm=0,
-                 norm_type=1,
-                 routed_scaling_factor=1,
-                 eps=float(1e-20),
-             )
-         else:
-             topk_weights, topk_ids = select_experts(
-                 hidden_states=x,
-                 router_logits=router_logits,
-                 use_grouped_topk=use_grouped_topk,
-                 top_k=top_k,
-                 renormalize=renormalize,
-                 topk_group=topk_group,
-                 num_expert_group=num_expert_group,
-                 num_fused_shared_experts=num_fused_shared_experts,
-                 custom_routing_function=custom_routing_function,
-                 correction_bias=correction_bias,
-                 torch_native=True,
-                 routed_scaling_factor=routed_scaling_factor,
-             )
+
+         topk_weights, topk_ids, _ = topk_output
          topk_ids = topk_ids.to(torch.int32)
          topk_weights = topk_weights.to(x.dtype)
          return npu_fused_experts(
@@ -1043,5 +979,5 @@
              w2_scale=layer.w2_weight_scale,
              topk_weights=topk_weights,
              topk_ids=topk_ids,
-             top_k=top_k,
+             top_k=topk_ids.shape[1],
          )
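
With top_k gone from the signature, the NPU path recovers it from the selection tensor itself: topk_ids holds one expert id per (token, slot), so its second dimension is exactly top_k. A tiny illustration, assuming the usual [num_tokens, top_k] routing layout:

import torch

# 4 tokens routed to top_k = 2 experts each.
topk_ids = torch.tensor([[0, 3], [1, 2], [3, 0], [2, 1]], dtype=torch.int32)
assert topk_ids.shape[1] == 2  # recovers top_k without passing it around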
sglang/srt/layers/radix_attention.py CHANGED
@@ -12,14 +12,16 @@
  # limitations under the License.
  # ==============================================================================
  """Radix attention."""
+ from __future__ import annotations

  from enum import Enum
- from typing import Optional
+ from typing import TYPE_CHECKING, Optional

  from torch import nn

- from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ if TYPE_CHECKING:
+     from sglang.srt.layers.quantization.base_config import QuantizationConfig
+     from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class AttentionType(Enum):
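
The radix_attention.py hunk is a standard recipe for cutting runtime import cycles: with the annotations future import, type hints are never evaluated at runtime, so imports used only in hints can move under TYPE_CHECKING. A generic sketch of the pattern (module and class names are placeholders, not sglang's):

from __future__ import annotations  # annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so no runtime cycle.
    from heavy_module import HeavyConfig  # placeholder import


def configure(cfg: HeavyConfig) -> None:  # OK: the hint is not evaluated
    ...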
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -5,7 +5,6 @@ from dataclasses import dataclass
  from typing import List, Optional, Sequence, Tuple

  import torch
- import torch.nn.functional as F
  from torch.nn.parameter import Parameter, UninitializedParameter

  from sglang.srt.distributed import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
      QuantizeMethodBase,
      method_has_implemented_embedding,
  )
+ from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
  from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

  DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()
  logger = logging.getLogger(__name__)


- class UnquantizedEmbeddingMethod(QuantizeMethodBase):
-     """Unquantized method for embeddings."""
-
-     def create_weights(
-         self,
-         layer: torch.nn.Module,
-         input_size_per_partition: int,
-         output_partition_sizes: List[int],
-         input_size: int,
-         output_size: int,
-         params_dtype: torch.dtype,
-         **extra_weight_attrs,
-     ):
-         """Create weights for embedding layer."""
-         weight = Parameter(
-             torch.empty(
-                 sum(output_partition_sizes),
-                 input_size_per_partition,
-                 dtype=params_dtype,
-             ),
-             requires_grad=False,
-         )
-         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-         layer.register_parameter("weight", weight)
-         set_weight_attrs(weight, extra_weight_attrs)
-
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         bias: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         return F.linear(x, layer.weight, bias)
-
-     def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-         return F.embedding(input_, layer.weight)
-
-
  def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
      """Pad the vocab size to the given value."""
      return ((vocab_size + pad_to - 1) // pad_to) * pad_to
@@ -569,8 +531,6 @@ class ParallelLMHead(VocabParallelEmbedding):
          if _is_cpu and _is_cpu_amx_available:
              if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
                  self.quant_method = PackWeightMethod(weight_names=["weight"])
-             else:
-                 logger.warning("The weight of LmHead is not packed")

          if bias:
              self.bias = Parameter(
sglang/srt/lora/lora.py CHANGED
@@ -186,10 +186,6 @@ class LoRAAdapter(nn.Module):
                  up_name = weight_name.replace("gate_proj", "up_proj")
                  gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                  if up_name not in weights:
-                     logger.warning(
-                         f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
-                         f"Initializing up projection to zero."
-                     )
                      weights[up_name] = torch.zeros_like(weights[weight_name])
                  # FIXME: Add gate-only support for flashinfer in future implementations
                  assert self.lora_backend.name == "triton", (