sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/_custom_ops.py +29 -1
  2. sglang/srt/configs/model_config.py +1 -1
  3. sglang/srt/conversation.py +1 -1
  4. sglang/srt/disaggregation/common/conn.py +34 -6
  5. sglang/srt/disaggregation/mini_lb.py +3 -2
  6. sglang/srt/disaggregation/mooncake/conn.py +49 -20
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  8. sglang/srt/disaggregation/nixl/conn.py +17 -13
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  10. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  11. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  12. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  13. sglang/srt/distributed/parallel_state.py +70 -15
  14. sglang/srt/entrypoints/engine.py +2 -8
  15. sglang/srt/entrypoints/http_server.py +20 -32
  16. sglang/srt/entrypoints/openai/protocol.py +3 -3
  17. sglang/srt/entrypoints/openai/serving_chat.py +27 -4
  18. sglang/srt/function_call/base_format_detector.py +74 -12
  19. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  20. sglang/srt/function_call/ebnf_composer.py +95 -63
  21. sglang/srt/function_call/function_call_parser.py +4 -4
  22. sglang/srt/function_call/kimik2_detector.py +41 -16
  23. sglang/srt/function_call/llama32_detector.py +6 -3
  24. sglang/srt/function_call/mistral_detector.py +11 -3
  25. sglang/srt/function_call/pythonic_detector.py +16 -14
  26. sglang/srt/function_call/qwen25_detector.py +12 -3
  27. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
  28. sglang/srt/layers/activation.py +11 -3
  29. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  30. sglang/srt/layers/communicator.py +12 -12
  31. sglang/srt/layers/dp_attention.py +72 -24
  32. sglang/srt/layers/logits_processor.py +34 -24
  33. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  35. sglang/srt/layers/moe/topk.py +5 -13
  36. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  37. sglang/srt/layers/quantization/modelopt_quant.py +8 -4
  38. sglang/srt/layers/quantization/utils.py +0 -9
  39. sglang/srt/layers/radix_attention.py +5 -3
  40. sglang/srt/lora/lora_manager.py +133 -169
  41. sglang/srt/lora/lora_registry.py +124 -0
  42. sglang/srt/lora/mem_pool.py +2 -2
  43. sglang/srt/managers/cache_controller.py +53 -6
  44. sglang/srt/managers/io_struct.py +19 -1
  45. sglang/srt/managers/schedule_batch.py +13 -3
  46. sglang/srt/managers/scheduler.py +13 -25
  47. sglang/srt/managers/tokenizer_manager.py +28 -25
  48. sglang/srt/managers/tp_worker.py +2 -4
  49. sglang/srt/mem_cache/allocator.py +67 -7
  50. sglang/srt/mem_cache/hicache_storage.py +17 -1
  51. sglang/srt/mem_cache/hiradix_cache.py +30 -16
  52. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  53. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  54. sglang/srt/model_executor/forward_batch_info.py +201 -29
  55. sglang/srt/model_executor/model_runner.py +41 -23
  56. sglang/srt/models/deepseek_v2.py +1 -2
  57. sglang/srt/models/mllama4.py +10 -3
  58. sglang/srt/models/qwen2_moe.py +0 -4
  59. sglang/srt/models/qwen3_moe.py +1 -6
  60. sglang/srt/reasoning_parser.py +46 -4
  61. sglang/srt/sampling/sampling_batch_info.py +6 -5
  62. sglang/srt/server_args.py +76 -55
  63. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  65. sglang/srt/speculative/eagle_utils.py +51 -23
  66. sglang/srt/speculative/eagle_worker.py +59 -44
  67. sglang/srt/two_batch_overlap.py +9 -5
  68. sglang/srt/utils.py +17 -68
  69. sglang/test/test_activation.py +50 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
  72. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
  73. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
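
The JSON above is a Triton tuning table for the fused-MoE kernel on NVIDIA H20 GPUs: each top-level key is a token-count bucket, and its value holds the tile sizes (BLOCK_SIZE_M/N/K, GROUP_SIZE_M) and pipeline settings (num_warps, num_stages) used at that batch size. A minimal sketch of how such a table can be consulted, assuming a nearest-bucket lookup; the pick_moe_config helper below is illustrative, not sglang's actual selection code:

    import json

    def pick_moe_config(configs: dict, num_tokens: int) -> dict:
        # Keys are stringified token counts; pick the bucket whose key is
        # numerically closest to the actual number of tokens (illustrative heuristic).
        best_key = min(configs, key=lambda k: abs(int(k) - num_tokens))
        return configs[best_key]

    with open("E=160,N=320,device_name=NVIDIA_H20-3e.json") as f:
        table = json.load(f)

    # 300 tokens is closest to the "256" bucket in the table above.
    print(pick_moe_config(table, 300))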

sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -53,9 +53,7 @@ elif _is_hip:
         from aiter import moe_sum
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-else:
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
+

 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
 
 
 @triton.jit
@@ -533,190 +528,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts,)
-    tokens_cnts = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1,)](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
-
-
-@triton.jit
-def init_sorted_ids_and_cumsum_buffer_kernel(
-    sorted_ids_ptr,
-    cumsum_buffer_ptr,
-    max_num_tokens_padded,
-    topk_ids_numel,
-    num_experts: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
-    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
-    if pid < sorted_ids_blocks:
-        mask = offsets < max_num_tokens_padded
-        tl.store(
-            sorted_ids_ptr + offsets,
-            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
-            mask=mask,
-        )
-    elif pid == sorted_ids_blocks:
-        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
-        mask_e = offset_e < num_experts + 1
-        tl.store(
-            cumsum_buffer_ptr + offset_e,
-            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
-            mask=mask_e,
-        )
-
-
-def init_sorted_ids_and_cumsum_buffer(
-    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
-):
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
-    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-
-    BLOCK_SIZE = 1024
-    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-    grid = (sorted_ids_blocks + 1,)
-
-    init_sorted_ids_and_cumsum_buffer_kernel[grid](
-        sorted_ids,
-        cumsum_buffer,
-        max_num_tokens_padded,
-        topk_ids_numel,
-        num_experts,
-        BLOCK_SIZE,
-        next_power_of_2(num_experts + 1),
-    )
-
-    return sorted_ids, cumsum_buffer
-
-
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -766,42 +577,32 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if enable_moe_align_block_size_triton:
-        sorted_ids.fill_(topk_ids.numel())
-        moe_align_block_size_triton(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        )
-    else:
-        cumsum_buffer = torch.empty(
-            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-        )
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts,
-            dtype=torch.int32,
-            device=topk_ids.device,
-        )
 
-        # Threshold based on benchmark results
-        fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
-        if not fuse_sorted_ids_padding:
-            sorted_ids.fill_(topk_ids.numel())
+    cumsum_buffer = torch.empty(
+        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+    )
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
 
-        sgl_moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer,
-            cumsum_buffer,
-            fuse_sorted_ids_padding,
-        )
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
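The fused_moe.py hunks above drop the pure-Triton moe_align_block_size path (and its ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON switch), so alignment now always goes through the fused sgl_moe_align_block_size op from sgl_kernel. As a reference for the op's contract only, here is a hedged pure-PyTorch sketch of the alignment it performs; it is not the kernel implementation:

    import torch

    def moe_align_block_size_reference(topk_ids, block_size, num_experts):
        # Group token slots by expert and pad each group to a multiple of
        # block_size, so every GEMM tile works on exactly one expert.
        flat = topk_ids.flatten()
        sentinel = flat.numel()  # padded slots point one past the last token slot
        sorted_chunks, expert_ids = [], []
        for e in range(num_experts):
            idx = (flat == e).nonzero(as_tuple=True)[0].to(torch.int32)
            pad = (-idx.numel()) % block_size
            idx = torch.cat([idx, torch.full((pad,), sentinel, dtype=torch.int32)])
            sorted_chunks.append(idx)
            expert_ids.extend([e] * (idx.numel() // block_size))
        sorted_ids = torch.cat(sorted_chunks)
        return (
            sorted_ids,
            torch.tensor(expert_ids, dtype=torch.int32),
            torch.tensor([sorted_ids.numel()], dtype=torch.int32),
        )

    # Example: 3 tokens, top-2 routing over 3 experts, blocks of 4 slots.
    ids, experts, total = moe_align_block_size_reference(
        torch.tensor([[0, 2], [1, 2], [2, 0]]), block_size=4, num_experts=3
    )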

sglang/srt/layers/moe/topk.py
@@ -15,7 +15,7 @@
 from __future__ import annotations
 
 import math
-from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch.nn.functional as F
@@ -39,10 +39,10 @@ from sglang.srt.utils import (
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -54,7 +54,6 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-
 if _is_npu:
     import torch_npu
 
@@ -387,6 +386,7 @@ def grouped_topk_cpu(
     )
 
 
+@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -482,7 +482,6 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -535,14 +534,7 @@
         )
         return topk_weights, topk_ids
     else:
-        biased_grouped_topk_fn = (
-            torch.compile(
-                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
-            )
-            if compiled
-            else biased_grouped_topk_impl
-        )
-        return biased_grouped_topk_fn(
+        return biased_grouped_topk_impl(
             hidden_states,
             gating_output,
             correction_bias,
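
The topk.py hunks above drop the runtime `compiled` flag and instead decorate `biased_grouped_topk_impl` with `torch.compile(..., disable=_is_npu)`; with `disable=True`, `torch.compile` leaves the eager function untouched. A standalone sketch of the pattern (the function below is a toy, not sglang code):

    import torch

    _is_npu = False  # stand-in for sglang's platform probe

    @torch.compile(dynamic=True, disable=_is_npu)
    def renormalize_topk_weights(weights: torch.Tensor) -> torch.Tensor:
        # When disable=True the decorator is a no-op, so the NPU path runs
        # eagerly without needing a separate un-compiled code branch.
        return weights / weights.sum(dim=-1, keepdim=True)

    print(renormalize_topk_weights(torch.rand(4, 8)).sum(dim=-1))  # ~1.0 per row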

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -28,15 +28,6 @@ if TYPE_CHECKING:
         CompressedTensorsConfig,
     )
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -568,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
+        from vllm import _custom_ops as vllm_ops
+
         marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,

sglang/srt/layers/quantization/modelopt_quant.py
@@ -952,7 +952,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
     ) -> torch.Tensor:
-
         assert activation == "silu", "Only SiLU activation is supported."
 
         if self.enable_flashinfer_moe:
@@ -982,13 +981,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_size=tp_size,
                 tp_rank=tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
-            )
-            return output[0]
+            )[0]
+            if routed_scaling_factor is not None:
+                output *= routed_scaling_factor
+            return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
         topk_weights, topk_ids, _ = topk_output
-        return cutlass_moe_fp4(
+        output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
             w1_fp4=layer.w13_weight,
@@ -1003,3 +1004,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             params=layer.cutlass_moe_params,
             apply_router_weight_on_input=apply_router_weight_on_input,
         ).to(x.dtype)
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output

sglang/srt/layers/quantization/utils.py
@@ -17,15 +17,6 @@ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_np
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm._custom_ops import scaled_fp8_quant
-
 
 def is_layer_skipped(
     prefix: str,

sglang/srt/layers/radix_attention.py
@@ -12,14 +12,16 @@
 # limitations under the License.
 # ==============================================================================
 """Radix attention."""
+from __future__ import annotations
 
 from enum import Enum
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from torch import nn
 
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class AttentionType(Enum):
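
The radix_attention.py hunk moves the QuantizationConfig and ForwardBatch imports behind `typing.TYPE_CHECKING` and adds `from __future__ import annotations`, so those modules are only imported by static type checkers and cannot create runtime import cycles. A minimal sketch of the pattern, using a hypothetical heavy_module:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers (mypy, pyright), never at runtime,
        # so this import cannot participate in a circular-import chain.
        from heavy_module import HeavyConfig  # hypothetical module

    def build(config: HeavyConfig) -> None:
        # With postponed evaluation the annotation stays a string at runtime,
        # so HeavyConfig never has to be resolved here.
        print(type(config).__name__)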