sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w4afp8.py ADDED
@@ -0,0 +1,264 @@
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from torch.nn import Module
+ from torch.nn.parameter import Parameter
+
+ from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+ from sglang.srt.layers.quantization.utils import is_layer_skipped
+ from sglang.srt.utils import set_weight_attrs
+
+ ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+ logger = logging.getLogger(__name__)
+
+
+ class W4AFp8Config(QuantizationConfig):
+     """Config class for MIXED_PRECISION W4AFp8."""
+
+     def __init__(
+         self,
+         is_checkpoint_fp8_serialized: bool = True,
+         is_checkpoint_w4afp8_serialized: bool = True,
+         linear_activation_scheme: str = "dynamic",
+         moe_activation_scheme: str = "static",
+         ignored_layers: Optional[List[str]] = None,
+         weight_block_size: Optional[List[int]] = None,
+         group_size: int = 128,
+     ) -> None:
+         super().__init__()
+         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+         self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
+         if is_checkpoint_w4afp8_serialized:
+             logger.warning("Detected w4afp8 checkpoint. Please note that")
+         if moe_activation_scheme not in ACTIVATION_SCHEMES:
+             raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
+         self.linear_activation_scheme = linear_activation_scheme
+         self.moe_activation_scheme = moe_activation_scheme
+         self.ignored_layers = ignored_layers or []
+         self.weight_block_size = [128, 128]
+         self.group_size = group_size
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "w4afp8"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.bfloat16, torch.float8_e4m3fn]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 90
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return []
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
+         quant_method = cls.get_from_keys(config, ["quant_method"])
+         is_checkpoint_fp8_serialized = "fp8" in quant_method
+         is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
+         linear_activation_scheme = "dynamic"
+         moe_activation_scheme = "static"
+         weight_block_size = [128, 128]
+         return cls(
+             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+             is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
+             linear_activation_scheme=linear_activation_scheme,
+             moe_activation_scheme=moe_activation_scheme,
+             weight_block_size=weight_block_size,
+         )
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+         if isinstance(layer, LinearBase):
+             if is_layer_skipped(prefix, self.ignored_layers):
+                 return UnquantizedLinearMethod()
+             return Fp8LinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return W4AFp8MoEMethod(self)
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class W4AFp8MoEMethod:
+
+     def __init__(self, quant_config: W4AFp8Config):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: Module,
+         num_experts_per_partition: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         assert "weight_loader" in extra_weight_attrs
+
+         # Fused gate_up_proj (column parallel)
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 intermediate_size * 2,
+                 hidden_size // 2,
+                 dtype=torch.int8,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         # down_proj (row parallel)
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 hidden_size,
+                 intermediate_size // 2,
+                 dtype=torch.int8,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         w13_weight_scale = torch.nn.Parameter(
+             torch.zeros(
+                 num_experts_per_partition,
+                 2 * intermediate_size,
+                 hidden_size // self.quant_config.group_size,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+         w2_weight_scale = torch.nn.Parameter(
+             torch.zeros(
+                 num_experts_per_partition,
+                 hidden_size,
+                 intermediate_size // self.quant_config.group_size,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         # Input scales
+         w13_input_scale = torch.nn.Parameter(
+             torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_input_scale", w13_input_scale)
+         set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+         w2_input_scale = torch.nn.Parameter(
+             torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_input_scale", w2_input_scale)
+         set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+         # Pre-populate the strides
+         device = layer.w13_weight.device
+
+         self.a_strides1 = torch.full(
+             (num_experts_per_partition, 3),
+             hidden_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.c_strides1 = torch.full(
+             (num_experts_per_partition, 3),
+             2 * intermediate_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.a_strides2 = torch.full(
+             (num_experts_per_partition, 3),
+             intermediate_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.c_strides2 = torch.full(
+             (num_experts_per_partition, 3),
+             hidden_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.b_strides1 = self.a_strides1
+         self.s_strides13 = self.c_strides1
+         self.b_strides2 = self.a_strides2
+         self.s_strides2 = self.c_strides2
+
+         self.expert_offsets = torch.empty(
+             (num_experts_per_partition + 1), dtype=torch.int32, device=device
+         )
+         self.problem_sizes1 = torch.empty(
+             (num_experts_per_partition, 3), dtype=torch.int32, device=device
+         )
+         self.problem_sizes2 = torch.empty(
+             (num_experts_per_partition, 3), dtype=torch.int32, device=device
+         )
+
+         return
+
+     def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
+         """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+         s_shape = scales.shape
+         # Reshape to separate groups of 4
+         scales_interleaved = scales.reshape(
+             s_shape[0], s_shape[1], (s_shape[2] // 4), 4
+         )
+         # Permute dimensions to interleave
+         scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+         # Reshape back to original dimensions but with interleaved values
+         scales_interleaved = scales_interleaved.reshape(
+             s_shape[0], s_shape[2] // 4, s_shape[1] * 4
+         )
+         return scales_interleaved.contiguous()
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+         dtype = torch.bfloat16
+         device = layer.w2_weight.device
+
+         # Interleave w13_weight_scale (gate_up_proj)
+         w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
+         w13_weight_scale = self._interleave_scales(w13_weight_scale)
+         layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
+
+         # Interleave w2_weight_scale (down_proj)
+         w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
+         w2_weight_scale = self._interleave_scales(w2_weight_scale)
+         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
+
+         # Process input scales
+         w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
+         new_w13_input_scale = torch.tensor(
+             [w13_input_scale_max],
+             dtype=dtype,
+             device=device,
+         )
+         layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)
+
+         w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
+         new_w2_input_scale = torch.tensor(
+             [w2_input_scale_max], dtype=dtype, device=device
+         )
+         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
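Note: the least obvious part of this new file is the scale interleaving. For each expert, the per-group quantization scales are regrouped in chunks of four so the CUTLASS kernel can read them contiguously. A standalone sketch (not part of the package) of the same reshape/permute/reshape on a toy tensor:

import torch

def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    # Same transform as W4AFp8MoEMethod._interleave_scales:
    # (experts, rows, groups) -> (experts, groups // 4, rows * 4).
    e, n, g = scales.shape
    return (
        scales.reshape(e, n, g // 4, 4)
        .permute(0, 2, 1, 3)
        .reshape(e, g // 4, n * 4)
        .contiguous()
    )

# One expert, two rows of eight group scales each.
scales = torch.arange(16, dtype=torch.float32).reshape(1, 2, 8)
out = interleave_scales(scales)
print(out.shape)  # torch.Size([1, 2, 8])
print(out[0])
# tensor([[ 0.,  1.,  2.,  3.,  8.,  9., 10., 11.],
#         [ 4.,  5.,  6.,  7., 12., 13., 14., 15.]])

Each output row now holds the group-g scales of all rows, four consecutive scales per source row.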
sglang/srt/layers/quantization/w8a8_int8.py CHANGED
@@ -4,6 +4,7 @@ import torch
  from torch.nn.parameter import Parameter
 
  from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
  from sglang.srt.layers.linear import LinearMethodBase
  from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
  from sglang.srt.layers.quantization.base_config import (
@@ -11,9 +12,17 @@ from sglang.srt.layers.quantization.base_config import (
      QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
- from sglang.srt.utils import is_cuda, set_weight_attrs
+ from sglang.srt.utils import (
+     cpu_has_amx_support,
+     is_cpu,
+     is_cuda,
+     set_weight_attrs,
+     use_intel_amx_backend,
+ )
 
  _is_cuda = is_cuda()
+ _is_cpu_amx_available = cpu_has_amx_support()
+ _is_cpu = is_cpu()
  if _is_cuda:
      from sgl_kernel import int8_scaled_mm
 
@@ -72,6 +81,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
          self.quantization_config = quantization_config
 
      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         if _is_cpu:
+             assert (
+                 _is_cpu_amx_available
+             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
+             _amx_process_weight_after_loading(layer, ["weight"])
+             return
+
          layer.weight = Parameter(layer.weight.t(), requires_grad=False)
          layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
@@ -112,6 +128,16 @@
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ):
+         if use_intel_amx_backend(layer):
+             return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
+                 x,
+                 layer.weight,
+                 layer.weight_scale,
+                 bias,
+                 x.dtype,
+                 True,  # is_vnni
+             )
+
          x_q, x_scale = per_token_quant_int8(x)
 
          return int8_scaled_mm(
@@ -206,6 +232,13 @@ class W8A8Int8MoEMethod:
          layer.register_parameter("w2_input_scale", w2_input_scale)
 
      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         if _is_cpu:
+             assert (
+                 _is_cpu_amx_available
+             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
+             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+             return
+
          layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
          layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
          layer.w13_weight_scale = Parameter(
@@ -252,6 +285,24 @@
              routed_scaling_factor=routed_scaling_factor,
          )
 
+         if use_intel_amx_backend(layer):
+             return torch.ops.sgl_kernel.fused_experts_cpu(
+                 x,
+                 layer.w13_weight,
+                 layer.w2_weight,
+                 topk_weights,
+                 topk_ids,
+                 False,  # inplace See [Note] inplace should be False in fused_experts.
+                 True,  # use_int8_w8a8
+                 False,  # use_fp8_w8a16
+                 layer.w13_weight_scale,  # w1_scale
+                 layer.w2_weight_scale,  # w2_scale
+                 None,  # block_size
+                 layer.w13_input_scale,  # a1_scale
+                 layer.w2_input_scale,  # a2_scale
+                 True,  # is_vnni
+             )
+
          return fused_experts(
              x,
              layer.w13_weight,
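Note: on the CUDA fallback path, `apply` still quantizes activations per token before calling `int8_scaled_mm`. As a rough plain-PyTorch reference for what `per_token_quant_int8` computes (an assumption about the fused kernel's math: symmetric int8 with one scale per row), consider this sketch:

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per row (token): amax / 127, then round and clamp.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 16)
x_q, x_scale = per_token_quant_int8_ref(x)
# Dequantization approximately recovers the input: x ~ x_q.float() * x_scale
print((x - x_q.float() * x_scale).abs().max())  # small, bounded by scale / 2

The AMX branch above skips this separate step entirely; `int8_scaled_mm_with_quant` fuses activation quantization into the matmul.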
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -660,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
          beta_slow: int = 1,
          mscale: float = 1,
          mscale_all_dim: float = 0,
-         device: Optional[str] = "cuda",
+         device: Optional[str] = "cuda" if not _is_npu else "npu",
      ) -> None:
          self.scaling_factor = scaling_factor
          self.extrapolation_factor = extrapolation_factor
@@ -679,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
          )
 
          # Re-dispatch
-         if _is_hip:
+         if _is_hip or _is_npu:
              self._forward_method = self.forward_native
 
      def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
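Note: one subtlety in the new `device` default is that `"cuda" if not _is_npu else "npu"` is evaluated once, when the function is defined, so it reflects the platform flag detected at import time. A minimal, self-contained demonstration of that Python behavior (names here are illustrative, not sglang APIs):

_is_npu = False

def make(device: str = "cuda" if not _is_npu else "npu") -> str:
    return device

_is_npu = True   # flipping the flag after import has no effect
print(make())    # prints "cuda": the default was bound at def time

This is fine here because `_is_npu` is a module-level constant set before the class definition runs.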
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -1,5 +1,6 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py
 
+ import logging
  from dataclasses import dataclass
  from typing import List, Optional, Sequence, Tuple
 
@@ -13,6 +14,7 @@ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.layers.amx_utils import PackWeightMethod
  from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
  from sglang.srt.layers.parameter import BasevLLMParameter
  from sglang.srt.layers.quantization.base_config import (
@@ -20,18 +22,15 @@ from sglang.srt.layers.quantization.base_config import (
      QuantizeMethodBase,
      method_has_implemented_embedding,
  )
- from sglang.srt.utils import (
-     PackWeightMethod,
-     cpu_has_amx_support,
-     is_cpu,
-     set_weight_attrs,
- )
+ from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
 
  DEFAULT_VOCAB_PADDING_SIZE = 64
 
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
 
+ logger = logging.getLogger(__name__)
+
 
  class UnquantizedEmbeddingMethod(QuantizeMethodBase):
      """Unquantized method for embeddings."""
@@ -250,8 +249,16 @@ class VocabParallelEmbedding(torch.nn.Module):
              self.tp_size = 1
 
          self.num_embeddings = num_embeddings
-         self.padding_size = padding_size
          self.org_vocab_size = org_num_embeddings or num_embeddings
+
+         # Support the case where the vocab size is not divisible by the TP size.
+         if (
+             _is_cpu
+             and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
+         ):
+             padding_size *= self.tp_size
+         self.padding_size = padding_size
+
          num_added_embeddings = num_embeddings - self.org_vocab_size
          self.use_presharded_weights = use_presharded_weights
          if use_presharded_weights:
@@ -558,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
          )
          self.quant_config = quant_config
 
-         # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
-         if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
-             self.quant_method = PackWeightMethod(weight_names=["weight"])
+         # We only support pack LMHead if it's not quantized.
+         if _is_cpu and _is_cpu_amx_available:
+             if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+                 self.quant_method = PackWeightMethod(weight_names=["weight"])
+             else:
+                 logger.warning("The weight of LmHead is not packed")
 
          if bias:
              self.bias = Parameter(
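Note: a worked example of the CPU padding fix above, assuming `pad_vocab_size` rounds the vocab up to a multiple of `padding_size` (consistent with its use here). With an awkward TP degree, the padded vocab is not evenly shardable until the padding is scaled by `tp_size`:

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Assumed definition: round up to a multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

org_vocab_size, tp_size, padding_size = 32000, 6, 64
padded = pad_vocab_size(org_vocab_size, padding_size)  # 32000
print(padded % tp_size)        # 2 -> cannot split evenly across 6 ranks

padding_size *= tp_size        # 384, as in the new code path
padded = pad_vocab_size(org_vocab_size, padding_size)  # 32256
print(padded % tp_size)        # 0 -> each rank owns 5376 rows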
sglang/srt/lora/lora.py CHANGED
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
          self.layers: List[LoRALayer] = nn.ModuleList(
              [
                  LoRALayer(config, base_hf_config)
-                 for i in range(base_hf_config.num_hidden_layers)
+                 for _ in range(base_hf_config.num_hidden_layers)
              ]
          )
 
@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
          else:
              self.weights[name] = loaded_weight.cpu()
 
-         # stack kv_proj and gate_up_proj
-         for i in range(self.base_hf_config.num_hidden_layers):
-             layer = self.layers[i]
-             weight_names = [name for name, _ in layer.weights.items()]
+         # normalize kv_proj and gate_up_proj
+         for layer in self.layers:
+             weight_names = list(layer.weights.keys())
              self.normalize_qkv_proj(weight_names, layer.weights)
              self.normalize_gate_up_proj(weight_names, layer.weights)
sglang/srt/lora/lora_manager.py CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
      get_normalized_lora_weight_names,
      get_weight_name,
  )
+ from sglang.srt.managers.io_struct import LoRAUpdateResult
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.utils import replace_submodule
 
@@ -98,44 +99,96 @@ class LoRAManager:
          ],
      )
 
-     def load_lora_adapters(self, lora_paths: Dict[str, str]):
+     def create_lora_update_result(
+         self, success: bool, error_message: str = ""
+     ) -> LoRAUpdateResult:
+         return LoRAUpdateResult(
+             success=success,
+             error_message=error_message,
+             loaded_adapters={
+                 name: config.path for name, config in self.configs.items()
+             },
+         )
+
+     def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
          """
          Load LoRA adapters from the specified paths.
-         TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
          Args:
              lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
                  If a LoRA adapter is already loaded, it will be skipped with a warning.
          """
 
+         results = []
          for lora_name, lora_path in lora_paths.items():
-             if lora_name in self.loras:
-                 logger.warning(
-                     f"LoRA adapter {lora_name} is already loaded."
-                     "If you want to reload it, please unload it first."
-                 )
-                 continue
+             result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+             results.append(result)
+
+         self.update_state_from_configs()
+
+         return self.create_lora_update_result(
+             success=all(result.success for result in results),
+             error_message="\n".join(
+                 result.error_message for result in results if not result.success
+             ),
+         )
+
+     def load_lora_adapter(
+         self, lora_name: str, lora_path: str, update_state: bool = True
+     ) -> LoRAUpdateResult:
+         """
+         Load a single LoRA adapter from the specified path.
+
+         Args:
+             lora_name (str): The name of the LoRA adapter.
+             lora_path (str): The file path to the LoRA adapter.
+             update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+         """
 
+         success = True
+         error_message = ""
+
+         if lora_name in self.loras:
+             success = False
+             error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+         try:
              self.configs[lora_name] = LoRAConfig(lora_path)
+         except Exception as e:
+             success = False
+             error_message = (
+                 f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+             )
 
-         self.update_state_from_configs()
+         if update_state:
+             self.update_state_from_configs()
+
+         return self.create_lora_update_result(
+             success=success,
+             error_message=error_message,
+         )
 
-     def unload_lora_adapters(self, lora_names: Set[str]):
+     def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
          """
          Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
          delete the corresponding LoRA modules.
-
-         Args:
-             lora_names (Set[str]): A set of LoRA adapter names to unload.
          """
-         for lora_name in lora_names:
-             if lora_name in self.loras:
-                 del self.configs[lora_name]
-             else:
-                 logger.warning(f"LoRA adapter {lora_name} is not loaded.")
+
+         success = True
+         error_message = ""
+         if lora_name in self.loras:
+             del self.configs[lora_name]
+         else:
+             error_message = f"LoRA adapter {lora_name} is not loaded."
+             success = False
 
          self.update_state_from_configs()
 
+         return self.create_lora_update_result(
+             success=success,
+             error_message=error_message,
+         )
+
      def prepare_lora_batch(self, forward_batch: ForwardBatch):
          # load active loras into lora memory pool
          cur_uids = set(forward_batch.lora_paths)
@@ -372,8 +425,8 @@ class LoRAManager:
          lora_adapter.initialize_weights()
          self.loras[name] = lora_adapter
 
-         # Clean up unused LoRA adapters
-         for name in self.loras:
+         # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+         for name in list(self.loras):
              if name not in self.configs:
                  logger.info(f"Unloading LoRA adapter {name}")
                  del self.loras[name]
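Note: the net effect of this refactor is that load/unload operations now report structured results instead of only logging. A hypothetical usage sketch (assuming `manager` is an initialized LoRAManager and the adapter paths exist; the `LoRAUpdateResult` fields are the ones populated by `create_lora_update_result` above):

# Hypothetical driver code, not part of the diff.
result = manager.load_lora_adapter("sql-lora", "/models/loras/sql-lora")
if not result.success:
    print(f"load failed: {result.error_message}")
print(f"currently loaded: {sorted(result.loaded_adapters)}")

# Batch loading aggregates per-adapter outcomes: success only if every
# adapter loaded; error messages are joined with newlines.
batch = manager.load_lora_adapters(
    {"sql-lora": "/models/loras/sql-lora", "chat-lora": "/models/loras/chat-lora"}
)
print(batch.success, batch.error_message)

# Unloading reports the same structured result.
result = manager.unload_lora_adapter("sql-lora")
assert "sql-lora" not in result.loaded_adapters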