sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py

@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
 
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
-    grid = lambda meta: (bs,)
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
-    fused_moe_router_kernel[grid](
+    fused_moe_router_kernel[(bs,)](
        x,
        router_weight,
        topk_weights,
        topk_ids,
+       correction_bias,
+       is_correction_bias=is_correction_bias,
        num_experts=num_experts,
        topk=topk,
        moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk == 1
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
-    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
-    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-    invsumexp = 1.0 / tl.sum(
-        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )
 
-    # 6. store to output
-    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    topk_mask = offs_topk < bs
-    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr + offs_topk,
-        invsumexp,
-        mask=topk_mask,
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )
 
+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )
+
 
 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
 
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk == 1
+    assert topk <= 2
 
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk == 1
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
         router_weight=gating_output,
         topk=topk,
        moe_softcapping=moe_softcapping,
+       correction_bias=correction_bias,
     )
 
 
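For orientation, the routing math these Triton kernels fuse can be written in a few lines of plain PyTorch. The sketch below is illustrative only (the function name is made up, and it assumes the usual tanh-style softcapping that the (exped - 1) / (exped + 1) * moe_softcapping expression implements): logits are softcapped, the optional correction bias is added after softcapping, and the weight stored for each selected expert is its softmax probability, which is exactly what the 1 / sum(exp(logits - top1)) and exp(top2 - top1) * top1_invsumexp expressions above compute.

import torch

def reference_router(x, router_weight, topk, moe_softcapping, correction_bias=None):
    # (bs, num_experts) logits, then tanh softcapping: tanh(z / cap) * cap
    logits = x @ router_weight.T
    capped = torch.tanh(logits / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        capped = capped + correction_bias  # added after softcapping, as in the kernel
    topk_ids = capped.topk(topk, dim=-1).indices
    # Softmax probability of each selected expert; for the top-1 expert this equals
    # 1 / sum(exp(capped - top1)), and for the top-2 expert it equals
    # exp(top2 - top1) * that value -- the same quantities the kernel stores.
    probs = torch.softmax(capped, dim=-1)
    topk_weights = probs.gather(-1, topk_ids)
    return topk_weights, topk_ids.to(torch.int32)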
sglang/srt/layers/moe/topk.py

@@ -18,34 +18,43 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.managers import expert_location_dispatch
-from sglang.srt.managers.expert_distribution import (
+from sglang.srt.eplb import expert_location_dispatch
+from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
-from sglang.srt.managers.expert_location_dispatch import (
+from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
 )
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
 
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
 
 def fused_topk_torch_native(
@@ -99,37 +108,14 @@ def fused_topk(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
 
     topk_softmax(
         topk_weights,
         topk_ids,
-        token_expert_indicies,
-        gating_output.float(),
-    )
-    del token_expert_indicies
-
-    return _fused_topk_postprocess(
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        renormalize=renormalize,
-        expert_location_dispatch_info=expert_location_dispatch_info,
-        num_token_non_padded=num_token_non_padded,
+        gating_output,
+        renormalize,
     )
 
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _fused_topk_postprocess(
-    topk_weights,
-    topk_ids,
-    renormalize,
-    expert_location_dispatch_info,
-    num_token_non_padded,
-):
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
@@ -152,6 +138,9 @@ def grouped_topk_gpu(
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = torch.softmax(gating_output, dim=-1)
+    # NPU compiler limitation
+    if _is_npu and scores.dtype == torch.bfloat16:
+        scores = scores.to(torch.float16)
     num_token = scores.shape[0]
     num_experts = scores.shape[1]
     group_scores = (
@@ -347,6 +336,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
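The new aiter branch above offloads biased, group-limited top-k selection to aiter's kernel when SGLANG_USE_AITER is set on ROCm. As a rough reference for the algorithm being accelerated, here is an approximation of DeepSeek-style biased grouped routing in plain PyTorch; it is not the aiter implementation, and the exact treatment of renormalization and the scaling factor is assumed for illustration.

import torch

def biased_grouped_topk_reference(gating_output, correction_bias, topk,
                                  num_expert_group, topk_group,
                                  renormalize=False, routed_scaling_factor=1.0):
    # Sigmoid scores; the correction bias steers *selection* only,
    # while the returned weights come from the unbiased scores.
    scores = gating_output.sigmoid()
    biased = scores + correction_bias
    bs, num_experts = scores.shape
    grouped = biased.view(bs, num_expert_group, -1)
    # Keep the topk_group strongest groups (ranked by their top-2 members).
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    keep = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(bs, num_experts)
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(-1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_ids.to(torch.int32)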
sglang/srt/layers/parameter.py

@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.utils import is_cpu
+
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -21,6 +23,8 @@ __all__ = [
 
 logger = logging.getLogger(__name__)
 
+_is_cpu = is_cpu()
+
 
 class BasevLLMParameter(Parameter):
     """
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.output_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
 
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+        from sglang.srt.model_loader.weight_utils import (
+            narrow_padded_param_and_loaded_weight,
+        )
+
+        if _is_cpu:
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                tp_rank * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
            )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
 
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.input_dim,
+                    shard_size,
+                )
+
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, tp_rank * shard_size, shard_size
+                )
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
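All three call sites above route CPU weight loading through narrow_padded_param_and_loaded_weight from sglang.srt.model_loader.weight_utils, whose body is not part of this excerpt. The idea is that on CPU the parameter tensor may be padded beyond the checkpoint's true size, so both the parameter and the loaded weight have to be narrowed to a common region before copy_. A minimal sketch of that idea follows, with the signature inferred from the call sites; the real helper may handle the padded tail differently.

import torch

def narrow_padded_param_and_loaded_weight_sketch(param_data, loaded_weight,
                                                 param_data_start, weight_start,
                                                 dim, shard_size, narrow_weight=True):
    if narrow_weight:
        # Take this rank's shard, clipped to what the checkpoint actually holds.
        valid = min(shard_size, loaded_weight.shape[dim] - weight_start)
        loaded_weight = loaded_weight.narrow(dim, weight_start, valid)
    # Narrow the (possibly padded) parameter to the same extent, so the
    # shapes line up for the copy_ that follows at the call site.
    param_data = param_data.narrow(dim, param_data_start, loaded_weight.shape[dim])
    return param_data, loaded_weight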
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py

@@ -76,7 +76,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         layer.input_scale = torch.nn.Parameter(
             layer.input_scale.data, requires_grad=False
         )
-        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+        prepare_fp8_layer_for_marlin(layer, size_k_first=True)
 
     def create_weights(
         self,
sglang/srt/layers/quantization/fp8.py

@@ -27,6 +27,7 @@ except ImportError:
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -73,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -330,6 +332,12 @@ class Fp8LinearMethod(LinearMethodBase):
                 )
 
                 layer.input_scale = None
+            elif _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["weight"])
+                return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
                 layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -426,6 +434,17 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         if self.block_quant:
+            if use_intel_amx_backend(layer):
+                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
+                    x,
+                    layer.weight,
+                    layer.weight_scale_inv,
+                    self.quant_config.weight_block_size,
+                    bias,
+                    x.dtype,
+                    True,  # is_vnni
+                )
+
             return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
@@ -746,6 +765,13 @@ class Fp8MoEMethod:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )
+
+            if _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +997,24 @@ class Fp8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                True,  # use_fp8_w8a16
+                layer.w13_weight_scale_inv,  # w1_scale
+                layer.w2_weight_scale_inv,  # w2_scale
+                self.quant_config.weight_block_size,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
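The fp8.py changes follow one pattern in both the linear and MoE methods: at weight-loading time a CPU build repacks the fp8 weights for Intel AMX (via _amx_process_weight_after_loading) and returns early, and at apply time use_intel_amx_backend(layer) short-circuits into the fused sgl_kernel CPU op instead of the GPU path. The skeleton below shows only that control flow; the class and helper names are placeholders, not sglang code.

import torch

def is_cpu_with_amx():
    # Placeholder for sglang's is_cpu() / cpu_has_amx_support() checks.
    return False

def repack_for_amx(layer, weight_names):
    # Placeholder for _amx_process_weight_after_loading (repack into the VNNI layout).
    pass

class AmxAwareMethod:
    """Schematic only: the early-return / early-dispatch structure used above."""

    def process_weights_after_loading(self, layer):
        if is_cpu_with_amx():
            repack_for_amx(layer, ["weight"])
            return  # skip the GPU-oriented post-processing entirely
        # ... existing CUDA / ROCm post-processing ...

    def apply(self, layer, x, bias=None):
        if getattr(layer, "use_intel_amx_backend", False):
            # Here the real code calls a fused torch.ops.sgl_kernel.*_cpu op.
            raise NotImplementedError("placeholder for the fused CPU kernel")
        return torch.nn.functional.linear(x, layer.weight, bias)  # generic fallback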
sglang/srt/layers/quantization/fp8_kernel.py

@@ -23,9 +23,9 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
+    align,
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
sglang/srt/layers/quantization/fp8_utils.py

@@ -1,9 +1,7 @@
 from typing import Callable, List, Optional, Tuple
 
-import einops
 import torch
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -27,6 +25,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
+    align,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
@@ -42,7 +41,10 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
-    from aiter import gemm_a8w8_blockscale_CK
+    import aiter
+    from aiter import gemm_a8w8_blockscale_CK, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -271,9 +273,7 @@ def aiter_w8a8_block_fp8_linear(
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    q_input, x_scale = per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
-    )
+    q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
     output = gemm_a8w8_blockscale_CK(
         q_input, weight, x_scale, weight_scale, dtype=input.dtype
     )
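aiter_w8a8_block_fp8_linear now quantizes activations with aiter's per_1x128 quantizer instead of the Triton per-token-group path. As a rough reference for what 1x128 (per-token, 128-wide group) fp8 quantization means, independent of either kernel, here is an illustrative PyTorch sketch; the function name is made up, and it needs a PyTorch build with float8 dtypes.

import torch

def per_1x128_quant_reference(x, group_size=128, fp8_dtype=torch.float8_e4m3fn):
    # x: (tokens, hidden) with hidden divisible by group_size
    tokens, hidden = x.shape
    groups = x.view(tokens, hidden // group_size, group_size)
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale per (token, group), chosen so the group's absolute max maps to fp8_max.
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scales).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q.view(tokens, hidden), scales.squeeze(-1)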
sglang/srt/layers/quantization/gptq.py

@@ -344,6 +344,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
+        assert (
+            VLLM_AVAILABLE
+        ), "vllm is not installed, to use gptq_marlin, please install vllm"
+
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
@@ -726,6 +730,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-            num_bits=self.quant_config.quant_type.size_bits,
+            quant_type_id=self.quant_config.quant_type.id,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
sglang/srt/layers/quantization/moe_wna16.py

@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = get_device_capability()
         device_capability = (
             -1
-            if capability_tuple is None
+            if all(capability is None for capability in capability_tuple)
            else capability_tuple[0] * 10 + capability_tuple[1]
        )
        # Avoid circular import
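The moe_wna16 change presumably guards against get_device_capability() returning a tuple of Nones, rather than None itself, on hosts without a CUDA-capable device: the old identity check would fall through to capability_tuple[0] * 10 + capability_tuple[1] and raise a TypeError. A tiny illustration of the difference:

capability_tuple = (None, None)                   # e.g. no CUDA device present
print(capability_tuple is None)                   # False -> old check falls through, then crashes
print(all(c is None for c in capability_tuple))   # True  -> new check yields -1 safely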