sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py CHANGED
@@ -1,7 +1,9 @@
+ from __future__ import annotations
+
  import importlib
  import sys
  from types import MappingProxyType
- from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

  import torch
  from torch.nn.parameter import Parameter
@@ -11,21 +13,20 @@ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
  )
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
- from sglang.srt.layers.linear import (
-     LinearMethodBase,
-     RowParallelLinear,
-     UnquantizedLinearMethod,
- )
  from sglang.srt.layers.parameter import (
      ChannelQuantScaleParameter,
      ModelWeightParameter,
      PerTensorScaleParameter,
  )
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
+ from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
  from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.utils import (
      apply_module_patch,
      cpu_has_amx_support,
@@ -36,6 +37,9 @@ from sglang.srt.utils import (
      use_intel_amx_backend,
  )

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
@@ -178,17 +182,18 @@ class W8A8Int8Config(QuantizationConfig):
      - Activation: dynamic, per-token, symmetric
      """

-     def __init__(self, quant_config: Dict[str, Any]):
+     def __init__(self, quant_config: Dict[str, Any] = {}):
          super().__init__()
          self.quant_description = quant_config
          self.is_dynamic = quant_config.get("is_dynamic", False)
-         if _is_npu:
-             if (
-                 "packed_modules_mapping" in quant_config
-                 and quant_config["packed_modules_mapping"] is not None
-             ):
-                 self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+         ignore = cast(List[str], quant_config.get("ignore", []))
+         self.ignore = ignore if ignore is not None else []
+         packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+         self.packed_modules_mapping = (
+             packed_modules_mapping if packed_modules_mapping is not None else {}
+         )

+         if _is_npu:
              # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
              for name in self.quant_description.keys():
                  if "norm.bias" in name:
@@ -229,14 +234,14 @@ class W8A8Int8Config(QuantizationConfig):
          return []

      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
          return cls(config)

      def get_quant_method(
          self,
          layer: torch.nn.Module,
          prefix: str,
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
          from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

@@ -262,12 +267,16 @@ class W8A8Int8Config(QuantizationConfig):
              elif isinstance(layer, FusedMoE):
                  return NPU_W8A8MoEMethod(self)
              return None
-         else:
-             if isinstance(layer, LinearBase):
-                 return W8A8Int8LinearMethod(self)
-             elif isinstance(layer, FusedMoE):
-                 return W8A8Int8MoEMethod(self)
-             return None
+
+         if should_ignore_layer(
+             prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+         ):
+             return UnquantizedLinearMethod()
+         if isinstance(layer, LinearBase):
+             return W8A8Int8LinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return W8A8Int8MoEMethod(self)
+         return None

      def is_layer_skipped(
          self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
@@ -374,7 +383,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
          )


- class W8A8Int8MoEMethod:
+ class W8A8Int8MoEMethod(FusedMoEMethodBase):
      """MoE method for INT8.
      Supports loading INT8 checkpoints with static weight scale and
      dynamic/static activation scale.
@@ -385,25 +394,7 @@ class W8A8Int8MoEMethod:
          quant_config: The quantization config.
      """

-     def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-         if not hasattr(cls, "_initialized"):
-             original_init = cls.__init__
-             new_cls = type(
-                 cls.__name__,
-                 (FusedMoEMethodBase,),
-                 {
-                     "__init__": original_init,
-                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                 },
-             )
-             obj = super(new_cls, new_cls).__new__(new_cls)
-             obj.__init__(*args, **kwargs)
-             return obj
-         return super().__new__(cls)
-
-     def __init__(self, quant_config):
+     def __init__(self, quant_config: W8A8Int8Config):
          self.quant_config = quant_config

      def create_weights(
@@ -481,15 +472,8 @@ class W8A8Int8MoEMethod:
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
+         topk_output: TopKOutput,
+         *,
          activation: str = "silu",
          apply_router_weight_on_input: bool = False,
          inplace: bool = True,
@@ -497,24 +481,14 @@ class W8A8Int8MoEMethod:
          routed_scaling_factor: Optional[float] = None,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-         from sglang.srt.layers.moe.topk import select_experts
-
-         # Expert selection
-         topk_weights, topk_ids = select_experts(
-             hidden_states=x,
-             router_logits=router_logits,
-             use_grouped_topk=use_grouped_topk,
-             top_k=top_k,
-             renormalize=renormalize,
-             topk_group=topk_group,
-             num_expert_group=num_expert_group,
-             num_fused_shared_experts=num_fused_shared_experts,
-             custom_routing_function=custom_routing_function,
-             correction_bias=correction_bias,
-             routed_scaling_factor=routed_scaling_factor,
-         )

          if use_intel_amx_backend(layer):
+             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+             topk_weights, topk_ids, _ = topk_output
+             x, topk_weights = apply_topk_weights_cpu(
+                 apply_router_weight_on_input, topk_weights, x
+             )
              return torch.ops.sgl_kernel.fused_experts_cpu(
                  x,
                  layer.w13_weight,
@@ -536,8 +510,7 @@ class W8A8Int8MoEMethod:
              x,
              layer.w13_weight,
              layer.w2_weight,
-             topk_weights=topk_weights,
-             topk_ids=topk_ids,
+             topk_output=topk_output,
              inplace=inplace,
              activation=activation,
              apply_router_weight_on_input=apply_router_weight_on_input,
@@ -761,6 +734,8 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
+         from sglang.srt.layers.linear import RowParallelLinear
+
          if isinstance(layer, RowParallelLinear):
              tp_rank = get_tensor_model_parallel_rank()
              return self.quant_method.apply(layer, x, bias, tp_rank)
@@ -885,13 +860,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
+         from sglang.srt.layers.linear import RowParallelLinear
+
          if isinstance(layer, RowParallelLinear):
              tp_rank = get_tensor_model_parallel_rank()
              return self.quant_method.apply(layer, x, bias, tp_rank)
          return self.quant_method.apply(layer, x, bias)


- class NPU_W8A8MoEMethod:
+ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
      """MoE method for NPU quantization.

      This class search for specific quantization
@@ -910,7 +887,7 @@ class NPU_W8A8MoEMethod:
          layer: torch.nn.Module,
          num_experts: int,
          hidden_size: int,
-         intermediate_size: List[int],
+         intermediate_size: int,
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ) -> None:
@@ -987,52 +964,11 @@ class NPU_W8A8MoEMethod:
          self,
          layer,
          x,
-         router_logits,
-         top_k,
-         renormalize,
-         use_grouped_topk,
-         topk_group,
-         num_expert_group,
-         num_fused_shared_experts,
-         custom_routing_function,
-         correction_bias,
-         activation,
-         apply_router_weight_on_input,
-         routed_scaling_factor,
+         topk_output: TopKOutput,
          **kwargs,
      ) -> torch.Tensor:
-         from sglang.srt.layers.moe.topk import select_experts
-
-         global_num_experts = router_logits.shape[-1]
-         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-         if global_num_experts == 256:
-             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                 router_logits,
-                 k=top_k,
-                 bias=correction_bias,
-                 k_group=topk_group,
-                 group_count=num_expert_group,
-                 group_select_mode=1,
-                 renorm=0,
-                 norm_type=1,
-                 routed_scaling_factor=1,
-                 eps=float(1e-20),
-             )
-         else:
-             topk_weights, topk_ids = select_experts(
-                 hidden_states=x,
-                 router_logits=router_logits,
-                 use_grouped_topk=use_grouped_topk,
-                 top_k=top_k,
-                 renormalize=renormalize,
-                 topk_group=topk_group,
-                 num_expert_group=num_expert_group,
-                 num_fused_shared_experts=num_fused_shared_experts,
-                 custom_routing_function=custom_routing_function,
-                 correction_bias=correction_bias,
-                 torch_native=True,
-                 routed_scaling_factor=routed_scaling_factor,
-             )
+
+         topk_weights, topk_ids, _ = topk_output
          topk_ids = topk_ids.to(torch.int32)
          topk_weights = topk_weights.to(x.dtype)
          return npu_fused_experts(
@@ -1043,5 +979,5 @@ class NPU_W8A8MoEMethod:
              w2_scale=layer.w2_weight_scale,
              topk_weights=topk_weights,
              topk_ids=topk_ids,
-             top_k=top_k,
+             top_k=topk_ids.shape[1],
          )
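Note on the MoE interface change above: W8A8Int8MoEMethod.apply and NPU_W8A8MoEMethod.apply no longer take raw router logits plus routing flags. They now receive a prepacked TopKOutput and simply unpack topk_weights, topk_ids, _ from it, with top_k recovered as topk_ids.shape[1]. The standalone sketch below illustrates that calling convention only; the TopKOutputSketch type, the select_topk helper, and the meaning of the third tuple field are assumptions made for illustration, not sglang's actual definitions.

# Minimal standalone sketch (not sglang's actual API) of the calling convention
# implied by this diff: expert selection happens once in the caller, and the
# quantized-MoE apply methods receive the prepacked result instead of raw
# router logits plus routing flags. The third field name is an assumption; the
# diff only shows it being unpacked and discarded (`topk_weights, topk_ids, _`).
from typing import NamedTuple

import torch


class TopKOutputSketch(NamedTuple):
    topk_weights: torch.Tensor   # [num_tokens, top_k] routing weights
    topk_ids: torch.Tensor       # [num_tokens, top_k] expert indices
    router_logits: torch.Tensor  # assumed third field, unused by the methods above


def select_topk(router_logits: torch.Tensor, top_k: int, renormalize: bool = True) -> TopKOutputSketch:
    """Toy softmax top-k routing, standing in for the caller-side TopK step."""
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return TopKOutputSketch(topk_weights, topk_ids, router_logits)


def apply_moe(topk_output: TopKOutputSketch) -> None:
    # Mirrors the new unpacking pattern in W8A8Int8MoEMethod / NPU_W8A8MoEMethod.
    topk_weights, topk_ids, _ = topk_output
    top_k = topk_ids.shape[1]  # same derivation now used for npu_fused_experts
    print(topk_weights.shape, topk_ids.shape, top_k)


if __name__ == "__main__":
    logits = torch.randn(4, 8)  # 4 tokens, 8 experts
    apply_moe(select_topk(logits, top_k=2))

Pushing expert selection out of the per-backend apply methods is what allows this file to delete its select_experts and npu_moe_gating_top_k branches.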
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -5,7 +5,6 @@ from dataclasses import dataclass
  from typing import List, Optional, Sequence, Tuple

  import torch
- import torch.nn.functional as F
  from torch.nn.parameter import Parameter, UninitializedParameter

  from sglang.srt.distributed import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
      QuantizeMethodBase,
      method_has_implemented_embedding,
  )
+ from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
  from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

  DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()
  logger = logging.getLogger(__name__)


- class UnquantizedEmbeddingMethod(QuantizeMethodBase):
-     """Unquantized method for embeddings."""
-
-     def create_weights(
-         self,
-         layer: torch.nn.Module,
-         input_size_per_partition: int,
-         output_partition_sizes: List[int],
-         input_size: int,
-         output_size: int,
-         params_dtype: torch.dtype,
-         **extra_weight_attrs,
-     ):
-         """Create weights for embedding layer."""
-         weight = Parameter(
-             torch.empty(
-                 sum(output_partition_sizes),
-                 input_size_per_partition,
-                 dtype=params_dtype,
-             ),
-             requires_grad=False,
-         )
-         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-         layer.register_parameter("weight", weight)
-         set_weight_attrs(weight, extra_weight_attrs)
-
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         bias: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         return F.linear(x, layer.weight, bias)
-
-     def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-         return F.embedding(input_, layer.weight)
-
-
  def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
      """Pad the vocab size to the given value."""
      return ((vocab_size + pad_to - 1) // pad_to) * pad_to
@@ -569,8 +531,6 @@ class ParallelLMHead(VocabParallelEmbedding):
          if _is_cpu and _is_cpu_amx_available:
              if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
                  self.quant_method = PackWeightMethod(weight_names=["weight"])
-             else:
-                 logger.warning("The weight of LmHead is not packed")

          if bias:
              self.bias = Parameter(
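The hunks above delete the local UnquantizedEmbeddingMethod definition and import it from the new sglang.srt.layers.quantization.unquant module instead, the same module that now provides UnquantizedLinearMethod in the w8a8_int8.py diff. A rough usage sketch follows, assuming the relocated class keeps the interface shown in the removed lines; the sizes are arbitrary and running it requires sglang 0.4.9.post3 to be installed.

import torch

# New import location introduced in this release; the class previously lived in
# sglang/srt/layers/vocab_parallel_embedding.py (see the removed hunk above).
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod

layer = torch.nn.Module()
method = UnquantizedEmbeddingMethod()
# Per the removed implementation, create_weights registers a
# [sum(output_partition_sizes), input_size_per_partition] parameter named "weight".
method.create_weights(
    layer,
    input_size_per_partition=16,   # embedding dimension
    output_partition_sizes=[32],   # vocab rows held by this partition
    input_size=16,
    output_size=32,
    params_dtype=torch.float32,
)
token_ids = torch.tensor([0, 3, 7])
# embedding() is a thin F.embedding wrapper, so this prints torch.Size([3, 16]).
print(method.embedding(layer, token_ids).shape)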
sglang/srt/lora/lora.py CHANGED
@@ -186,10 +186,6 @@ class LoRAAdapter(nn.Module):
                  up_name = weight_name.replace("gate_proj", "up_proj")
                  gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                  if up_name not in weights:
-                     logger.warning(
-                         f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
-                         f"Initializing up projection to zero."
-                     )
                      weights[up_name] = torch.zeros_like(weights[weight_name])
                  # FIXME: Add gate-only support for flashinfer in future implementations
                  assert self.lora_backend.name == "triton", (
sglang/srt/lora/lora_manager.py CHANGED
@@ -16,7 +16,7 @@
  # and "Punica: Multi-Tenant LoRA Serving"

  import logging
- from typing import Dict, Set, Tuple
+ from typing import Dict, Iterable, Optional, Set, Tuple

  import torch

@@ -53,6 +53,8 @@ class LoRAManager:
          lora_backend: str = "triton",
          tp_size: int = 1,
          tp_rank: int = 0,
+         max_lora_rank: Optional[int] = None,
+         target_modules: Optional[Iterable[str]] = None,
      ):
          self.base_model: torch.nn.Module = base_model
          self.base_hf_config: AutoConfig = base_hf_config
@@ -62,6 +64,10 @@ class LoRAManager:
          self.device: torch.device = next(self.base_model.parameters()).device
          self.tp_size: int = tp_size
          self.tp_rank: int = tp_rank
+         self.max_lora_rank: Optional[int] = max_lora_rank
+         self.target_modules: Optional[Set[str]] = (
+             set(target_modules) if target_modules else None
+         )

          # LoRA backend for running sgemm kernels
          logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -153,7 +159,9 @@ class LoRAManager:
          error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."

          try:
-             self.configs[lora_name] = LoRAConfig(lora_path)
+             new_adapter = LoRAConfig(lora_path)
+             self.validate_new_adapter(lora_name, new_adapter)
+             self.configs[lora_name] = new_adapter
          except Exception as e:
              success = False
              error_message = (
@@ -168,6 +176,21 @@ class LoRAManager:
              error_message=error_message,
          )

+     def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
+         """
+         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
+         """
+
+         incompatible = self.memory_pool and not self.memory_pool.can_support(
+             lora_config
+         )
+         if incompatible:
+             raise ValueError(
+                 f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                 "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
+                 "included in `--enable_lora_modules`."
+             )
+
      def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
          """
          Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
@@ -214,7 +237,7 @@ class LoRAManager:
              weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
              if lora_path is not None:
                  lora = self.loras[lora_path]
-                 lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                 lora_ranks[weight_indices[i]] = lora.config.r
                  scalings[weight_indices[i]] = lora.scaling

          # Use pinned memory to avoid synchronizations during host-to-device transfer
@@ -319,7 +342,7 @@ class LoRAManager:
              )
          else:
              weight_name = get_weight_name(
-                 module_name, self.lora_weight_names, LoRAType.LORA_A
+                 module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
              )
              module.set_lora_info(
                  self.memory_pool.get_tensor(
@@ -351,58 +374,67 @@ class LoRAManager:
              i: {} for i in range(self.base_hf_config.num_hidden_layers)
          }

-         # Initialize memory pool
-         self.memory_pool = LoRAMemoryPool(
-             self.base_hf_config,
-             self.max_loras_per_batch,
-             self.dtype,
-             self.tp_size,
-             self.tp_rank,
-         )
+         # The LoRA memory pool that manages the GPU buffers for active LoRA weights.
+         # It is initialized lazily when the first LoRA adapter is loaded.
+         self.memory_pool: Optional[LoRAMemoryPool] = None

      def update_state_from_configs(self):
          """
          Update the internal state of the LoRAManager based on the current `self.configs`. This method
          should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-
-         This includes:
-         - Initializing LoRA adapters if they are not already loaded.
-         - Collect all LoRA weight names based on the current loaded adapters.
-         - Lazily monkey-patching the base model to use LoRA layers where applicable.
-         - Preparing the GPU buffer pool for active LoRA weights.
          """

-         # Target module names in huggingface lora configs.
-         # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-         hf_target_module_names: Set[str] = set()
-         for config in self.configs.values():
-             hf_target_module_names.update(config.target_modules)
-         max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
          # Loads / unloads LoRA adapters based on the latest configs.
          self.update_lora_adapters()
+         # Apply the latest LoRA configurations to the internal state for inferencing.
+         self.apply_lora_configs()
+
+     def apply_lora_configs(self):
+         """
+         Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
+
+         Notes:
+         - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
+           we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
+           LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
+           early CY25H2.
+         """
+
+         if self.memory_pool is None:
+             # Infer max_lora_rank and target_modules if not explicitly specified in server args.
+             if self.target_modules is None:
+                 self.target_modules = set()
+                 for config in self.configs.values():
+                     self.target_modules.update(config.target_modules)
+
+             if self.max_lora_rank is None:
+                 self.max_lora_rank = max(
+                     [x.hf_config["r"] for x in self.configs.values()],
+                     default=0,
+                 )
+
+             self.update_lora_weight_names()
+             self.update_lora_modules()
+             self.update_memory_buffers()
+         else:
+             # No-op if the memory pool can support the current LoRA configurations.
+             # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
+             # module is changed once FlashInfer backend is deprecated.
+             assert self.memory_pool.can_support(self.configs.values()), (
+                 "LoRA memory pool cannot support the current LoRA configuration. "
+                 "This should never happen as we should have validated adapter compatibility. "
+                 "Please create a Github issue to report.",
+             )

-         # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
-         #
-         # Please note that the following update operations are "monotonic" by design, meaning that we update
-         # multiple places to support the new weight names when the first adapter targeting such weight names
-         # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
-         # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
-         # list of LoRA weight names is expected to be extremely finite and stable.
-         self.update_lora_weight_names(hf_target_module_names)
-         self.update_lora_modules(hf_target_module_names)
-         self.update_memory_buffers(max_lora_dim)
-
-     def update_lora_weight_names(self, hf_target_names: Set[str]):
+     def update_lora_weight_names(self):
          """
          Add new LoRA weight names if needed based on the current `self.configs`.
          """

          # Target lora weight names for lora_a and lora_b modules respectively.
-         for module in hf_target_names:
-             lora_A, lora_B = get_normalized_lora_weight_names(module)
-             self.lora_weight_names[0].update(lora_A)
-             self.lora_weight_names[1].update(lora_B)
+         lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
+         self.lora_weight_names[0].update(lora_A)
+         self.lora_weight_names[1].update(lora_B)

      def update_lora_adapters(self):
          """
@@ -434,21 +466,23 @@ class LoRAManager:
          # Additional checks for flashinfer backend
          # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
          if self.lora_backend == "flashinfer":
-             lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+             lora_dims = set(x.r for x in self.configs.values())
              scalings = set(x.scaling for x in self.loras.values())
              assert (
                  len(lora_dims) == 1 and len(scalings) == 1
              ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

-     def update_memory_buffers(self, max_lora_dim: int):
-         """
-         Update the LoRA memory pool buffers based on the current LoRA configurations and update
-         LoRA modules to use the new buffers. This method should be called after the LoRA configurations
-         are set or updated.
-         """
-
-         self.memory_pool.init_buffers(
-             self.lora_weight_names, self.base_model, max_lora_dim
+     def update_memory_buffers(self):
+         """(Re)initialize the LoRA memory pool based on the current configurations."""
+         self.memory_pool = LoRAMemoryPool(
+             base_hf_config=self.base_hf_config,
+             max_loras_per_batch=self.max_loras_per_batch,
+             dtype=self.dtype,
+             tp_size=self.tp_size,
+             tp_rank=self.tp_rank,
+             max_lora_rank=self.max_lora_rank,
+             lora_weight_names=self.lora_weight_names,
+             base_model=self.base_model,
          )

      def set_lora_module(self, module_name, module):
@@ -456,11 +490,11 @@ class LoRAManager:
          replace_submodule(self.base_model, module_name, lora_module)
          return lora_module

-     def update_lora_modules(self, hf_target_names: Set[str]):
+     def update_lora_modules(self):
          # Target module names of customized layers defined in python/sglang/srt/layers
          # e.g., {"qkv_proj", "o_proj"}
          customized_target_names = get_customized_names_from_hf_names(
-             hf_target_names, self.base_model
+             self.target_modules, self.base_model
          )

          for module_name, module in self.base_model.named_modules():
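The LoRAManager changes above make the memory pool lazy: max_lora_rank and target_modules can be passed in (or inferred from the first loaded adapters), the pool is built on first use, and validate_new_adapter rejects adapters the existing pool cannot hold via LoRAMemoryPool.can_support. Below is a hypothetical sketch of that compatibility rule, inferred from the error message in the diff (rank within --max_lora_rank, target modules within the configured set); the names and signature are illustrative, not sglang's implementation.

# Hypothetical sketch of the adapter-compatibility rule that validate_new_adapter
# (added above) relies on via LoRAMemoryPool.can_support: an adapter fits an
# already-initialized pool only if its rank does not exceed the configured
# max_lora_rank and its target modules are a subset of the modules the pool was
# built for. Names and signature here are illustrative only.
from dataclasses import dataclass
from typing import Set


@dataclass
class AdapterShape:
    r: int                     # LoRA rank, as in the LoRAConfig.r used by the diff
    target_modules: Set[str]   # e.g. {"q_proj", "k_proj", "v_proj", "o_proj"}


def can_support(pool_max_rank: int, pool_target_modules: Set[str], adapter: AdapterShape) -> bool:
    return adapter.r <= pool_max_rank and adapter.target_modules <= pool_target_modules


# A pool sized for rank-16 attention adapters rejects a rank-32 adapter,
# which is the case the new validate_new_adapter turns into a ValueError.
pool_modules = {"q_proj", "k_proj", "v_proj", "o_proj"}
print(can_support(16, pool_modules, AdapterShape(8, {"q_proj", "v_proj"})))   # True
print(can_support(16, pool_modules, AdapterShape(32, {"q_proj", "v_proj"})))  # False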