sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/router.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
 
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
-    grid = lambda meta: (bs,)
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
-    fused_moe_router_kernel[grid](
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk == 1
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
-    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
-    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-    invsumexp = 1.0 / tl.sum(
-        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )
 
-    # 6. store to output
-    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    topk_mask = offs_topk < bs
-    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr + offs_topk,
-        invsumexp,
-        mask=topk_mask,
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )
 
+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )
+
 
 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
 
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk == 1
+    assert topk <= 2
 
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk == 1
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
             router_weight=gating_output,
             topk=topk,
             moe_softcapping=moe_softcapping,
+            correction_bias=correction_bias,
         )
 
 
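The correction_bias path adds the bias to the softcapped logits before top-k selection, and the large-batch kernel now handles topk == 2 by storing the second expert's weight as exp(top2_v - top1_v) * top1_invsumexp. For intuition, here is a plain-PyTorch sketch of the routing math as it appears in these kernels (illustrative only, not sglang code; how the final weights treat the bias case is not visible in this hunk, so the sketch simply softmaxes the biased logits):

import torch


def reference_moe_router(x, router_weight, topk, moe_softcapping, correction_bias=None):
    # (bs, hidden_dim) @ (num_experts, hidden_dim)^T -> (bs, num_experts)
    logits = x.float() @ router_weight.float().t()
    # tanh-style softcapping: (exp(2z/c) - 1) / (exp(2z/c) + 1) * c == c * tanh(z / c)
    capped = moe_softcapping * torch.tanh(logits / moe_softcapping)
    if correction_bias is not None:
        # The kernel adds the bias after softcapping, before the top-k step.
        capped = capped + correction_bias
    # The stored weights equal softmax probabilities of the capped logits:
    # top1 weight = 1 / sum(exp(l - top1_v)); top2 weight = exp(top2_v - top1_v) * top1 weight.
    probs = torch.softmax(capped, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, k=topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)


# Example: reference_moe_router(torch.randn(4, 64), torch.randn(8, 64), topk=2, moe_softcapping=30.0)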

sglang/srt/layers/moe/topk.py
@@ -18,12 +18,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.managers import expert_location_dispatch
-from sglang.srt.managers.expert_distribution import (
+from sglang.srt.eplb import expert_location_dispatch
+from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
-from sglang.srt.managers.expert_location_dispatch import (
+from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
 )
 
 _is_cuda = is_cuda()
@@ -42,6 +43,7 @@ _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -106,37 +108,14 @@ def fused_topk(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
 
     topk_softmax(
         topk_weights,
         topk_ids,
-        token_expert_indicies,
-        gating_output.float(),
-    )
-    del token_expert_indicies
-
-    return _fused_topk_postprocess(
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        renormalize=renormalize,
-        expert_location_dispatch_info=expert_location_dispatch_info,
-        num_token_non_padded=num_token_non_padded,
+        gating_output,
+        renormalize,
     )
 
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _fused_topk_postprocess(
-    topk_weights,
-    topk_ids,
-    renormalize,
-    expert_location_dispatch_info,
-    num_token_non_padded,
-):
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
@@ -159,6 +138,9 @@ def grouped_topk_gpu(
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = torch.softmax(gating_output, dim=-1)
+    # NPU compiler limitation
+    if _is_npu and scores.dtype == torch.bfloat16:
+        scores = scores.to(torch.float16)
     num_token = scores.shape[0]
     num_experts = scores.shape[1]
     group_scores = (
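fused_topk now passes gating_output and renormalize straight into topk_softmax and drops both the token_expert_indicies scratch buffer and the _fused_topk_postprocess wrapper; only the logical-to-physical remapping and padded-region masking stay in Python. A rough PyTorch equivalent of the softmax / top-k / renormalize contract (sketch only, not the kernel):

import torch


def reference_fused_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over experts, keep the top-k, optionally renormalize the kept weights to sum to 1.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)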

sglang/srt/layers/parameter.py
@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.utils import is_cpu
+
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -21,6 +23,8 @@ __all__ = [
 logger = logging.getLogger(__name__)
 
 
+_is_cpu = is_cpu()
+
 
 class BasevLLMParameter(Parameter):
     """
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.output_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
 
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+        from sglang.srt.model_loader.weight_utils import (
+            narrow_padded_param_and_loaded_weight,
+        )
+
+        if _is_cpu:
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                tp_rank * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
            shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
 
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.input_dim,
+                    shard_size,
+                )
+
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, tp_rank * shard_size, shard_size
+                )
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
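On CPU the loaders now route through narrow_padded_param_and_loaded_weight from sglang.srt.model_loader.weight_utils, which narrows both the (possibly padded) parameter and the checkpoint shard; the non-CPU path keeps the plain Tensor.narrow slice. A standalone sketch of that baseline slicing, with hypothetical shapes, for illustration only:

import torch

# Slice this rank's tensor-parallel shard out of a full checkpoint tensor.
full_weight = torch.randn(1024, 4096)  # hypothetical full weight; output_dim = 0
tp_size, tp_rank, output_dim = 4, 1, 0
shard_size = full_weight.shape[output_dim] // tp_size  # 256 rows per rank

# Equivalent of: loaded_weight.narrow(self.output_dim, tp_rank * shard_size, shard_size)
shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
assert shard.shape == (256, 4096)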

sglang/srt/layers/quantization/__init__.py
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
+    "w4afp8": W4AFp8Config,
 }
 
 # VLLM-dependent quantization methods

sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@@ -76,7 +76,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
             layer.input_scale = torch.nn.Parameter(
                 layer.input_scale.data, requires_grad=False
             )
-        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+        prepare_fp8_layer_for_marlin(layer, size_k_first=True)
 
     def create_weights(
         self,

sglang/srt/layers/quantization/fp8.py
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -27,6 +27,7 @@ except ImportError:
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -73,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -86,7 +88,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if _is_hip:
+if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -198,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
        quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -284,7 +286,10 @@
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
             if self.block_quant:
-                assert self.quant_config.activation_scheme == "dynamic"
+                if hasattr(self.quant_config, "activation_scheme"):
+                    assert self.quant_config.activation_scheme == "dynamic"
+                elif hasattr(self.quant_config, "linear_activation_scheme"):
+                    assert self.quant_config.linear_activation_scheme == "dynamic"
                 scale = BlockQuantScaleParameter(
                     data=torch.empty(
                         (output_size_per_partition + block_n - 1) // block_n,
@@ -306,7 +311,13 @@
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 scale = PerTensorScaleParameter(
                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                     weight_loader=weight_loader,
@@ -330,6 +341,12 @@
                 )
 
                 layer.input_scale = None
+            elif _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["weight"])
+                return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
                 layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -363,7 +380,13 @@
             layer.weight_scale = torch.nn.Parameter(
                 layer.weight_scale.data, requires_grad=False
             )
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
@@ -397,7 +420,13 @@
                 # Update layer with new values.
                 layer.weight = Parameter(weight.t(), requires_grad=False)
                 layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-                if self.quant_config.activation_scheme == "static":
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
                     layer.input_scale = Parameter(
                         layer.input_scale.max(), requires_grad=False
                     )
@@ -426,6 +455,17 @@
             )
 
         if self.block_quant:
+            if use_intel_amx_backend(layer):
+                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
+                    x,
+                    layer.weight,
+                    layer.weight_scale_inv,
+                    self.quant_config.weight_block_size,
+                    bias,
+                    x.dtype,
+                    True,  # is_vnni
+                )
+
             return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
@@ -746,6 +786,13 @@ class Fp8MoEMethod:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )
+
+            if _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +1018,24 @@
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                True,  # use_fp8_w8a16
+                layer.w13_weight_scale_inv,  # w1_scale
+                layer.w2_weight_scale_inv,  # w2_scale
+                self.quant_config.weight_block_size,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
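Several hunks above repeat the same hasattr check because Fp8LinearMethod.__init__ now also accepts a W4AFp8Config, which (judging by this diff) exposes linear_activation_scheme where Fp8Config exposes activation_scheme. The condition could be factored into a helper along these lines (sketch only, not part of sglang):

from typing import Optional


def resolve_activation_scheme(quant_config) -> Optional[str]:
    # Assumption: Fp8Config has `activation_scheme`, W4AFp8Config has `linear_activation_scheme`.
    if hasattr(quant_config, "activation_scheme"):
        return quant_config.activation_scheme
    if hasattr(quant_config, "linear_activation_scheme"):
        return quant_config.linear_activation_scheme
    return None


# Usage mirroring the repeated condition in the diff:
#     if resolve_activation_scheme(self.quant_config) == "static": ...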

sglang/srt/layers/quantization/fp8_kernel.py
@@ -23,9 +23,9 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
+    align,
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,

sglang/srt/layers/quantization/fp8_utils.py
@@ -1,9 +1,7 @@
 from typing import Callable, List, Optional, Tuple
 
-import einops
 import torch
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -27,6 +25,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
+    align,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
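Both fp8_kernel.py and fp8_utils.py now import align from sglang.srt.utils rather than from sglang.math_utils (which loses 8 lines in this release, per the file list). The diff does not show the helper's body; a typical round-up-to-multiple definition would look like this (assumed semantics, not the verified sglang implementation):

def align(x: int, alignment: int) -> int:
    # Round x up to the nearest multiple of `alignment` (assumption).
    return (x + alignment - 1) // alignment * alignment


assert align(130, 64) == 192
assert align(128, 64) == 128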

sglang/srt/layers/quantization/gptq.py
@@ -344,6 +344,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
+        assert (
+            VLLM_AVAILABLE
+        ), "vllm is not installed, to use gptq_marlin, please install vllm"
+
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
@@ -726,6 +730,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-            num_bits=self.quant_config.quant_type.size_bits,
+            quant_type_id=self.quant_config.quant_type.id,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)