sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax


-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids


+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(

 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids


+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids


-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
     )


+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
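The block above makes backend selection a one-time, import-level decision: the public names grouped_topk, biased_grouped_topk, and fused_topk_native are rebound to either the CPU (AMX) or GPU implementations, so call sites stay unchanged. A minimal, self-contained sketch of the same pattern with hypothetical helpers (not sglang code):

# Minimal sketch of import-time backend dispatch (hypothetical names).
# Each backend-specific implementation is defined once; the generic name is
# bound to the right one when the module is imported, so callers never branch.

def _topk_indices_reference(scores, k):
    # Plain-Python reference path: sort scores and keep the k best indices.
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return order[:k]

def _topk_indices_fast(scores, k):
    # Stand-in for an optimized kernel (e.g. an AMX or CUDA op).
    return _topk_indices_reference(scores, k)

def _has_fast_backend() -> bool:
    # Stand-in for checks like is_cpu() and cpu_has_amx_support().
    return False

# Bind the public name exactly once, at import time.
topk_indices = _topk_indices_fast if _has_fast_backend() else _topk_indices_reference

if __name__ == "__main__":
    print(topk_indices([0.1, 0.7, 0.2, 0.9], k=2))  # -> [3, 1]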
@@ -14,14 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs

 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant

@@ -64,9 +64,12 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     get_bool_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
@@ -74,6 +77,9 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 _is_fp8_fnuz = is_fp8_fnuz()

@@ -82,10 +88,11 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_hip:
     from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant


@@ -1045,15 +1052,15 @@ class Fp8MoEMethod:
         if _use_hip_int4:
             # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."
-            return ck_moe_2stages(
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                QuantType.per_Token,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
+                quant_type=QuantType.per_Token,
+                w1_scale=layer.w13_weight_scale1,
+                w2_scale=layer.w2_weight_scale1,
                 activation=(
                     ActivationType.Silu if activation == "silu" else ActivationType.Gelu
                 ),
@@ -1062,31 +1069,32 @@ class Fp8MoEMethod:
         if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-                # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time-being.
-                assert (
-                    activation == "silu"
-                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
-                return asm_moe(
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
+                    w1_scale=layer.w13_weight_scale_inv,
+                    w2_scale=layer.w2_weight_scale_inv,
+                    quant_type=QuantType.per_128x128,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                     expert_mask=None,
                 )
             else:
-                return ck_moe_2stages(
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
+                    quant_type=QuantType.per_Token,
+                    w1_scale=layer.w13_weight_scale1,
+                    w2_scale=layer.w2_weight_scale1,
                     activation=(
                         ActivationType.Silu
                         if activation == "silu"
@@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, next_power_of_2

 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

+try:
+    from flashinfer import fp4_quantize as fp4_quantize
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+except ImportError:
+    flashinfer_cutlass_fused_moe = None
+
 # Initialize logger for the module
 logger = logging.getLogger(__name__)

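The try/except above is the usual optional-dependency guard: when FlashInfer is not installed, flashinfer_cutlass_fused_moe is set to None and the accelerated path is skipped at call time. A small sketch of the idiom with a made-up package name:

# Optional-dependency guard: prefer an accelerated kernel when importable,
# otherwise fall back and record that the fast path is unavailable.
try:
    from some_optional_package import fast_kernel  # hypothetical package
except ImportError:
    fast_kernel = None

def run(x):
    if fast_kernel is not None:
        return fast_kernel(x)
    # Portable fallback path.
    return [v * 2 for v in x]

if __name__ == "__main__":
    print(run([1, 2, 3]))  # uses the fallback when the import failed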
@@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         layer.alpha = Parameter(
             layer.input_scale * layer.weight_scale_2, requires_grad=False
         )
+        layer.input_scale_inv = Parameter(
+            (1 / input_scale_2).to(torch.float32), requires_grad=False
+        )

         # Pad and blockwise interleave weight_scale
         scales = layer.weight_scale
@@ -467,7 +476,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]

         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)

         assert x_fp4.dtype == torch.uint8
         assert x_scale_interleaved.dtype == torch.float8_e4m3fn
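The two hunks above replace a division in the forward path (1 / layer.input_scale) with a reciprocal computed once while the weights are processed (layer.input_scale_inv), so the hot path only hands a ready-made tensor to scaled_fp4_quant. The idea in isolation, using plain torch rather than the sglang classes:

import torch

scale = torch.tensor(0.02, dtype=torch.float32)

# Before: divide on every call (extra work in the hot path).
def apply_scale_slow(x: torch.Tensor) -> torch.Tensor:
    return x * (1.0 / scale)

# After: precompute the reciprocal once, reuse it on every call.
scale_inv = (1.0 / scale).to(torch.float32)

def apply_scale_fast(x: torch.Tensor) -> torch.Tensor:
    return x * scale_inv

x = torch.randn(4)
assert torch.allclose(apply_scale_slow(x), apply_scale_fast(x))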
@@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod:
                 " quantization. Please use Blackwell and"
                 " above."
             )
+        self.enable_flashinfer_moe = False

     def create_weights(
         self,
@@ -674,7 +684,10 @@ class ModelOptNvFp4FusedMoEMethod:
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        if self.enable_flashinfer_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
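The FlashInfer branch collapses the recorded input scales to a single global maximum with .max(), while the existing branch keeps one scale per expert via .max(dim=1).values. The shape difference on a toy tensor:

import torch

# Toy 2-D tensor; rows stand in for experts here (illustration only).
w13_input_scale = torch.tensor([[0.5, 0.7],
                                [1.1, 0.9],
                                [0.3, 0.4]])

per_expert = w13_input_scale.max(dim=1).values  # tensor([0.7, 1.1, 0.4]), one value per row
global_max = w13_input_scale.max()              # tensor(1.1), a single scalar

print(per_expert.shape, global_max.shape)  # torch.Size([3]) torch.Size([])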
@@ -700,14 +713,19 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

         # GEMM 2
+        if self.enable_flashinfer_moe:
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w2_input_scale = layer.w2_input_scale
+
         layer.g2_alphas = Parameter(
-            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )

         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )

         assert (
@@ -727,11 +745,16 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.cutlass_moe_params = CutlassMoEParams(
             CutlassMoEType.BlockscaledFP4,
             device,
-            num_experts=layer.num_experts,
+            num_experts=layer.num_experts,  # global num experts
             intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
             hidden_size=layer.w13_weight.shape[2] * 2,
         )  # k

+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
+        return self.enable_flashinfer_moe
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -750,11 +773,13 @@ class ModelOptNvFp4FusedMoEMethod:
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        ep_rank: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> torch.Tensor:

         assert activation == "silu", "Only SiLU activation is supported."
-
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts

         topk_weights, topk_ids = select_experts(
@@ -771,6 +796,35 @@ class ModelOptNvFp4FusedMoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )

+        if self.enable_flashinfer_moe:
+            assert (
+                not apply_router_weight_on_input
+            ), "apply_router_weight_on_input is not supported for Flashinfer"
+            # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
+            # and fp4 quantized weights loaded from the checkpoint
+            output = flashinfer_cutlass_fused_moe(
+                x,
+                topk_ids.to(torch.int),
+                topk_weights,
+                layer.w13_weight.view(torch.long),
+                layer.w2_weight.view(torch.long),
+                x.dtype,
+                quant_scales=[
+                    layer.w13_input_scale_quant,
+                    layer.w13_blockscale_swizzled.view(torch.int32),
+                    layer.g1_alphas,
+                    layer.w2_input_scale_quant,
+                    layer.w2_blockscale_swizzled.view(torch.int32),
+                    layer.g2_alphas,
+                ],
+                ep_size=ep_size,
+                ep_rank=ep_rank,
+                tp_size=tp_size,
+                tp_rank=tp_rank,
+                tune_max_num_tokens=next_power_of_2(x.shape[0]),
+            )
+            return output[0]
+
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

         return cutlass_moe_fp4(
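tune_max_num_tokens is rounded up with next_power_of_2, imported from sglang.srt.utils in an earlier hunk. The sketch below shows the intended rounding behaviour only; the actual helper in sglang.srt.utils may differ in edge-case handling.

def next_power_of_2(n: int) -> int:
    # Round n up to the nearest power of two (the usual bit-length trick).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert [next_power_of_2(i) for i in (1, 2, 3, 5, 8, 9)] == [1, 2, 4, 8, 8, 16]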
@@ -6,11 +6,14 @@ from typing import List, Mapping, Tuple, Union
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant


@@ -8,10 +8,13 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -84,7 +87,9 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)

-        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+        if (
+            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
+        ) and not (_is_cpu and _is_cpu_amx_available):
             from vllm._custom_ops import rotary_embedding

             self.vllm_rotary_embedding = rotary_embedding
@@ -147,6 +152,26 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
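Both new forward_cpu methods follow the same shape: fold the optional offsets into positions, call the fused CPU op when AMX support was detected at import time, and otherwise fall back to forward_native. A stripped-down sketch of that capability-gated fallback with a hypothetical class (plain torch, not the sglang CustomOp API):

import torch

class RotaryLike:
    """Hypothetical op with a fast path gated on a capability flag."""

    def __init__(self, has_fast_kernel: bool):
        self.has_fast_kernel = has_fast_kernel

    def forward_native(self, positions: torch.Tensor) -> torch.Tensor:
        # Portable reference implementation.
        return positions.float() * 2.0

    def _fast_kernel(self, positions: torch.Tensor) -> torch.Tensor:
        # Stand-in for a fused op such as torch.ops.sgl_kernel.rotary_embedding_cpu.
        return positions.float() * 2.0

    def forward_cpu(self, positions, offsets=None):
        positions = torch.add(positions, offsets) if offsets is not None else positions
        if self.has_fast_kernel:
            return self._fast_kernel(positions)
        return self.forward_native(positions)

op = RotaryLike(has_fast_kernel=False)
print(op.forward_cpu(torch.arange(4), offsets=torch.ones(4, dtype=torch.long)))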
@@ -696,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key = key_rot
         return query.to(dtype), key.to(dtype)

+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions, query, key, self.head_size, self.cos_sin_cache, False
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+

 class Llama3RotaryEmbedding(RotaryEmbedding):

@@ -91,7 +91,7 @@ class Sampler(nn.Module):
                 )
             else:
                 batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                    probs,
+                    probs.contiguous(),
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
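The sampler now passes probs.contiguous() because fused sampling kernels generally expect a dense, row-major tensor; a strided view produced upstream could otherwise be rejected or read with the wrong layout. The behavior in plain torch:

import torch

probs = torch.rand(4, 8).t()          # transpose -> a non-contiguous view
print(probs.is_contiguous())          # False
dense = probs.contiguous()            # materializes a row-major copy
print(dense.is_contiguous())          # True
print(torch.equal(probs, dense))      # True: same values, different memory layout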