sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py
@@ -0,0 +1,342 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.moe.topk import fused_topk
+
+
+@triton.jit
+def fused_moe_router_kernel(
+    input_ptr,  # input (bs, hidden_dim)
+    moe_router_weight_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    # moe_router_weight is k major
+    expert_offsets = tl.arange(0, num_experts)[:, None]
+    router_mask = mask[None, :]
+    w_router = tl.load(
+        moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
+        mask=router_mask,
+        other=0.0,
+    )
+
+    x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
+
+    # todo: tl.dot?
+    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
+
+    # logit softcap
+    logits_scaled = logits / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    top = exped - 1
+    bottom = exped + 1
+    logits_softcapped = top / bottom * moe_softcapping
+
+    # topk
+    # assert 1 <= topk <= num_experts
+
+    # 5.38 us
+
+    top1 = tl.argmax(logits_softcapped, axis=0)
+    tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us
+
+    top1_v = tl.max(logits_softcapped, axis=0)
+    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
+
+    tl.store(
+        topk_weights_ptr + pid * topk + 0,
+        invsumexp,
+    )  # 5.73 us
+
+    if topk >= 2:
+        top2 = tl.argmax(
+            tl.where(
+                tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
+            ),
+            axis=0,
+        )
+        tl.store(topk_ids_ptr + pid * topk + 1, top2)
+        top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
+        tl.store(
+            topk_weights_ptr + pid * topk + 1,
+            tl.exp(top2_v - top1_v) * invsumexp,
+        )  # 5.95us
+
+    # probably slow
+    if topk > 2:
+        topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
+        )
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
+        )
+        for i in range(2, topk):
+            topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
+            topk_mask = tl.where(
+                tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
+            )
+            tl.store(topk_ids_ptr + pid * topk + i, topi)
+            topi_v = tl.sum(
+                logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
+            )
+            tl.store(
+                topk_weights_ptr + pid * topk + i,
+                tl.exp(topi_v - top1_v) * invsumexp,
+            )
+    # assert not moe_renormalize, "moe weight renormalization not implemented"
+
+
+def fused_moe_router_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = lambda meta: (bs,)
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_moe_router_kernel[grid](
+        x,
+        router_weight,
+        topk_weights,
+        topk_ids,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        hidden_dim=hidden_dim,
+        **config,
+    )
+
+    return topk_weights, topk_ids
+
+
+@triton.jit
+def fused_moe_router_large_bs_kernel(
+    a_ptr,  # input (bs, hidden_dim)
+    b_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    bs,
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,  # only support topk == 1
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    stride_am: tl.constexpr,
+    stride_bn: tl.constexpr,
+):
+
+    # 1. get block id
+    pid = tl.program_id(axis=0)
+
+    # 2. create pointers for the first block of A and B
+    # 2.1. setup a_ptrs with offsets in m and k
+    offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
+    bs_mask = offs_m < bs
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
+
+    # 2.2. setup b_ptrs with offsets in k and n.
+    # Note: b matrix is k-major.
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
+    expert_mask = offs_n < num_experts
+    b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
+
+    # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
+    # 3.1. iterate in K dimension
+    # 3.2. transpose tile B
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
+        a = tl.load(
+            a_ptrs,
+            mask=bs_mask,
+            other=0.0,
+        ).to(tl.float32)
+        b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
+        acc += tl.dot(a, b)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+
+    # 4. logit softcap
+    logits_scaled = acc / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+
+    # 5. top1
+    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
+    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    top1_v = tl.max(
+        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+    )
+    invsumexp = 1.0 / tl.sum(
+        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    )
+
+    # 6. store to output
+    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    topk_mask = offs_topk < bs
+    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    tl.store(
+        topk_weights_ptr + offs_topk,
+        invsumexp,
+        mask=topk_mask,
+    )
+
+
+def fused_moe_router_large_bs_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+    BLOCK_SIZE_M: int,
+    BLOCK_SIZE_N: int,
+    BLOCK_SIZE_K: int,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    assert num_experts <= BLOCK_SIZE_N
+    assert hidden_dim % BLOCK_SIZE_K == 0
+    assert topk == 1
+
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
+
+    fused_moe_router_large_bs_kernel[grid](
+        a_ptr=x,
+        b_ptr=router_weight,
+        topk_weights_ptr=topk_weights,
+        topk_ids_ptr=topk_ids,
+        bs=bs,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        K=hidden_dim,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        stride_am=hidden_dim,
+        stride_bn=hidden_dim,
+    )
+
+    return topk_weights, topk_ids
+
+
+def fused_moe_router_shim(
+    moe_softcapping,
+    hidden_states,
+    gating_output,
+    topk,
+    renormalize,
+):
+    assert not renormalize
+    assert (
+        len(hidden_states.shape) == 2
+        and hidden_states.shape[1] == gating_output.shape[1]
+    )
+    bs, hidden_dim = hidden_states.shape
+    num_experts = gating_output.shape[0]
+    BLOCK_SIZE_M = 32
+    BLOCK_SIZE_N = 16
+    BLOCK_SIZE_K = 256
+    if (
+        bs >= 512
+        and topk == 1
+        and num_experts <= BLOCK_SIZE_N
+        and hidden_dim % BLOCK_SIZE_K == 0
+    ):
+        return fused_moe_router_large_bs_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            BLOCK_SIZE_K=BLOCK_SIZE_K,
+        )
+    else:
+        return fused_moe_router_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+        )
+
+
+class FusedMoeRouter:
+    def __init__(self, router_linear, topk, moe_softcapping) -> None:
+        self.router_linear = router_linear
+        self.topk = topk
+        self.moe_softcapping = moe_softcapping
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_vllm(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_moe_router_shim(
+            moe_softcapping=self.moe_softcapping,
+            hidden_states=x,
+            gating_output=self.router_linear.weight,
+            topk=self.topk,
+            renormalize=False,
+        )
+
+    def forward_vllm(
+        self,
+        x: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # g, _ = self.router_linear.forward(x)
+        g = x.float() @ self.router_linear.weight.T.float()
+
+        g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
+
+        return fused_topk(x, g, self.topk, False)
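
The new router is driven through fused_moe_router_shim, which uses the large-batch GEMM kernel only when bs >= 512, topk == 1, and the block-size constraints hold, and otherwise falls back to the per-token kernel. Below is a minimal usage sketch; the shapes and the moe_softcapping value of 30.0 are illustrative assumptions (not taken from this diff), and a CUDA device with Triton installed is assumed.

import torch

from sglang.srt.layers.moe.router import fused_moe_router_shim

# Illustrative sizes only; num_experts is kept a power of two because the
# per-token kernel builds expert masks with tl.arange(0, num_experts).
bs, hidden_dim, num_experts, topk = 8, 1024, 8, 2
x = torch.randn(bs, hidden_dim, device="cuda", dtype=torch.float16)
router_weight = torch.randn(num_experts, hidden_dim, device="cuda", dtype=torch.float16)

topk_weights, topk_ids = fused_moe_router_shim(
    moe_softcapping=30.0,          # assumed softcap value, model-dependent
    hidden_states=x,
    gating_output=router_weight,   # router weight matrix, (num_experts, hidden_dim)
    topk=topk,
    renormalize=False,             # the fused path does not support renormalization
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([8, 2]) torch.Size([8, 2])

With bs=8 this exercises the fused_moe_router_impl branch; the large-batch kernel would only be selected at bs >= 512 with topk == 1.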
sglang/srt/layers/moe/topk.py
@@ -17,7 +17,14 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+
+expert_distribution_recorder = ExpertDistributionRecorder()
 
 
 def fused_topk_native(
@@ -47,7 +54,10 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    from vllm import _custom_ops as ops
+    if _is_cuda or _is_hip:
+        from sgl_kernel import topk_softmax
+    else:
+        from vllm import _custom_ops as vllm_ops
 
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -61,12 +71,20 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),
-    )
+    if _is_cuda or _is_hip:
+        topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),
+        )
+    else:
+        vllm_ops.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),
+        )
     del token_expert_indicies
 
     if renormalize:
@@ -75,6 +93,7 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
+# This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def grouped_topk(
     hidden_states: torch.Tensor,
@@ -83,17 +102,10 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
-    if scoring_func == "softmax":
-        scores = torch.softmax(gating_output, dim=-1)
-    elif scoring_func == "sigmoid":
-        scores = gating_output.sigmoid()
-    else:
-        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
-
+    scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -117,7 +129,6 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
-# DeepSeek V2/V3/R1 uses biased_grouped_top
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
@@ -172,7 +183,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeepSeek V2/V3/R1 uses biased_grouped_top
+    # DeekSeekv2 uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -217,4 +228,6 @@ def select_experts(
             renormalize=renormalize,
         )
 
+    expert_distribution_recorder.record_new_token(topk_ids)
+
     return topk_weights, topk_ids
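
For readers following the fused_topk change above: on CUDA/HIP the call now routes through sgl_kernel.topk_softmax instead of vllm's _custom_ops, but the quantity computed is the same. A plain-PyTorch sketch of those semantics (softmax over the gating logits, per-token top-k, optional renormalization) follows for clarity; the name fused_topk_reference is hypothetical and this is a reference illustration, not the code path shipped in the package.

import torch

def fused_topk_reference(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    # Same contract as fused_topk: one gating row per token.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)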
sglang/srt/layers/parameter.py
@@ -105,6 +105,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
+        tp_rank = kwargs.get("tp_rank")
         use_presharded_weights = kwargs.get("use_presharded_weights")
         if (
             isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
@@ -116,7 +117,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
 
         param_data = self.data
 
-        tp_rank = get_tensor_model_parallel_rank()
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         if not use_presharded_weights:
             loaded_weight = loaded_weight.narrow(