sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py (new file)
@@ -0,0 +1,342 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.moe.topk import fused_topk
+
+
+@triton.jit
+def fused_moe_router_kernel(
+    input_ptr,  # input (bs, hidden_dim)
+    moe_router_weight_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    # moe_router_weight is k major
+    expert_offsets = tl.arange(0, num_experts)[:, None]
+    router_mask = mask[None, :]
+    w_router = tl.load(
+        moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
+        mask=router_mask,
+        other=0.0,
+    )
+
+    x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
+
+    # todo: tl.dot?
+    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
+
+    # logit softcap
+    logits_scaled = logits / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    top = exped - 1
+    bottom = exped + 1
+    logits_softcapped = top / bottom * moe_softcapping
+
+    # topk
+    # assert 1 <= topk <= num_experts
+
+    # 5.38 us
+
+    top1 = tl.argmax(logits_softcapped, axis=0)
+    tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us
+
+    top1_v = tl.max(logits_softcapped, axis=0)
+    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
+
+    tl.store(
+        topk_weights_ptr + pid * topk + 0,
+        invsumexp,
+    )  # 5.73 us
+
+    if topk >= 2:
+        top2 = tl.argmax(
+            tl.where(
+                tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
+            ),
+            axis=0,
+        )
+        tl.store(topk_ids_ptr + pid * topk + 1, top2)
+        top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
+        tl.store(
+            topk_weights_ptr + pid * topk + 1,
+            tl.exp(top2_v - top1_v) * invsumexp,
+        )  # 5.95us
+
+    # probably slow
+    if topk > 2:
+        topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
+        )
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
+        )
+        for i in range(2, topk):
+            topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
+            topk_mask = tl.where(
+                tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
+            )
+            tl.store(topk_ids_ptr + pid * topk + i, topi)
+            topi_v = tl.sum(
+                logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
+            )
+            tl.store(
+                topk_weights_ptr + pid * topk + i,
+                tl.exp(topi_v - top1_v) * invsumexp,
+            )
+    # assert not moe_renormalize, "moe weight renormalization not implemented"
+
+
+def fused_moe_router_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = lambda meta: (bs,)
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_moe_router_kernel[grid](
+        x,
+        router_weight,
+        topk_weights,
+        topk_ids,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        hidden_dim=hidden_dim,
+        **config,
+    )
+
+    return topk_weights, topk_ids
+
+
+@triton.jit
+def fused_moe_router_large_bs_kernel(
+    a_ptr,  # input (bs, hidden_dim)
+    b_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    bs,
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,  # only support topk == 1
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    stride_am: tl.constexpr,
+    stride_bn: tl.constexpr,
+):
+
+    # 1. get block id
+    pid = tl.program_id(axis=0)
+
+    # 2. create pointers for the first block of A and B
+    # 2.1. setup a_ptrs with offsets in m and k
+    offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
+    bs_mask = offs_m < bs
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
+
+    # 2.2. setup b_ptrs with offsets in k and n.
+    # Note: b matrix is k-major.
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
+    expert_mask = offs_n < num_experts
+    b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
+
+    # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
+    # 3.1. iterate in K dimension
+    # 3.2. transpose tile B
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
+        a = tl.load(
+            a_ptrs,
+            mask=bs_mask,
+            other=0.0,
+        ).to(tl.float32)
+        b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
+        acc += tl.dot(a, b)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+
+    # 4. logit softcap
+    logits_scaled = acc / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+
+    # 5. top1
+    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
+    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    top1_v = tl.max(
+        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+    )
+    invsumexp = 1.0 / tl.sum(
+        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    )
+
+    # 6. store to output
+    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    topk_mask = offs_topk < bs
+    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    tl.store(
+        topk_weights_ptr + offs_topk,
+        invsumexp,
+        mask=topk_mask,
+    )
+
+
+def fused_moe_router_large_bs_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+    BLOCK_SIZE_M: int,
+    BLOCK_SIZE_N: int,
+    BLOCK_SIZE_K: int,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    assert num_experts <= BLOCK_SIZE_N
+    assert hidden_dim % BLOCK_SIZE_K == 0
+    assert topk == 1
+
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
+
+    fused_moe_router_large_bs_kernel[grid](
+        a_ptr=x,
+        b_ptr=router_weight,
+        topk_weights_ptr=topk_weights,
+        topk_ids_ptr=topk_ids,
+        bs=bs,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        K=hidden_dim,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        stride_am=hidden_dim,
+        stride_bn=hidden_dim,
+    )
+
+    return topk_weights, topk_ids
+
+
+def fused_moe_router_shim(
+    moe_softcapping,
+    hidden_states,
+    gating_output,
+    topk,
+    renormalize,
+):
+    assert not renormalize
+    assert (
+        len(hidden_states.shape) == 2
+        and hidden_states.shape[1] == gating_output.shape[1]
+    )
+    bs, hidden_dim = hidden_states.shape
+    num_experts = gating_output.shape[0]
+    BLOCK_SIZE_M = 32
+    BLOCK_SIZE_N = 16
+    BLOCK_SIZE_K = 256
+    if (
+        bs >= 512
+        and topk == 1
+        and num_experts <= BLOCK_SIZE_N
+        and hidden_dim % BLOCK_SIZE_K == 0
+    ):
+        return fused_moe_router_large_bs_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            BLOCK_SIZE_K=BLOCK_SIZE_K,
+        )
+    else:
+        return fused_moe_router_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+        )
+
+
+class FusedMoeRouter:
+    def __init__(self, router_linear, topk, moe_softcapping) -> None:
+        self.router_linear = router_linear
+        self.topk = topk
+        self.moe_softcapping = moe_softcapping
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_vllm(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_moe_router_shim(
+            moe_softcapping=self.moe_softcapping,
+            hidden_states=x,
+            gating_output=self.router_linear.weight,
+            topk=self.topk,
+            renormalize=False,
+        )
+
+    def forward_vllm(
+        self,
+        x: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # g, _ = self.router_linear.forward(x)
+        g = x.float() @ self.router_linear.weight.T.float()
+
+        g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
+
+        return fused_topk(x, g, self.topk, False)
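The new `sglang/srt/layers/moe/router.py` fuses the MoE gating matmul, logit softcapping, and top-k selection into a single Triton launch, with a separate `tl.dot`-based kernel for large batches with `topk == 1`. The softcap is computed as (exp(2x) - 1) / (exp(2x) + 1) * cap, which is algebraically tanh(x) * cap. Below is a minimal, illustrative call into the shim; the tensor shapes and the softcap value 30.0 are assumptions for the example, not values taken from this diff.

```python
# Illustrative sketch only: requires a CUDA device and Triton.
import torch

from sglang.srt.layers.moe.router import fused_moe_router_shim

bs, hidden_dim, num_experts, topk = 8, 4096, 8, 2
hidden_states = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device="cuda")
router_weight = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, device="cuda")

# hidden_states: (bs, hidden_dim); gating_output is the router *weight*, k-major
# (num_experts, hidden_dim). Returns (bs, topk) weights and expert ids per token.
topk_weights, topk_ids = fused_moe_router_shim(
    moe_softcapping=30.0,   # example softcap constant, not from the diff
    hidden_states=hidden_states,
    gating_output=router_weight,
    topk=topk,
    renormalize=False,      # the kernel asserts renormalization is off
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```

When bs >= 512, topk == 1, num_experts <= 16, and hidden_dim is divisible by 256, the shim dispatches to the large-batch kernel instead of the per-token kernel.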
sglang/srt/layers/parameter.py
@@ -16,6 +16,7 @@ __all__ = [
     "ModelWeightParameter",
     "ChannelQuantScaleParameter",
     "GroupQuantScaleParameter",
+    "BlockQuantScaleParameter",
     "PackedColumnParameter",
     "RowvLLMParameter",
 ]
@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
     pass


+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
+
+
 class PerTensorScaleParameter(BasevLLMParameter):
     """
     Parameter class for scales where the number of scales is
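`BlockQuantScaleParameter` now lives in `sglang/srt/layers/parameter.py` (it previously came from `fp8_utils.py`), so block-quantized backends can import it alongside the other parameter classes. The sketch below shows how such a scale parameter is typically constructed, assuming the vLLM-style keyword constructor (`data`, `input_dim`, `output_dim`, `weight_loader`) used by the sibling classes in this module; the shapes and the `load_scale` helper are invented for the example.

```python
# Sketch only: one float32 scale per 128x128 weight block, output (column-parallel)
# dimension first, input (row-parallel) dimension second.
import torch

from sglang.srt.layers.parameter import BlockQuantScaleParameter

def load_scale(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
    # Trivial loader for the sketch; real loaders shard along the input/output dims.
    param.data.copy_(loaded_weight)

out_features, in_features = 1536, 7168   # e.g. one of the N=1536,K=7168 shapes above
block_n, block_k = 128, 128

weight_scale_inv = BlockQuantScaleParameter(
    data=torch.empty(
        (out_features + block_n - 1) // block_n,
        (in_features + block_k - 1) // block_k,
        dtype=torch.float32,
    ),
    input_dim=1,
    output_dim=0,
    weight_loader=load_scale,
)
```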
sglang/srt/layers/quantization/__init__.py
@@ -1,4 +1,6 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+import builtins
+import inspect
 import re
 from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
@@ -6,10 +8,7 @@ from typing import Callable, Dict, Optional, Type, Union
 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import (
-    AWQMarlinConfig,
-    AWQMoEMethod,
-)
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
@@ -28,6 +27,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -50,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "w8a8_int8": W8A8Int8Config,
+    "w8a8_fp8": W8A8Fp8Config,
 }


@@ -178,96 +179,117 @@ def gptq_get_quant_method(self, layer, prefix):
     return None


-def awq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
-    from vllm.model_executor.layers.quantization.awq_marlin import (
-        AWQMarlinLinearMethod,
-        AWQMoEMethod,
-    )
+original_isinstance = builtins.isinstance

-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

-    if isinstance(layer, LinearBase) or (
-        isinstance(layer, ParallelLMHead) and self.lm_head_quantized
-    ):
-        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
-            return UnquantizedLinearMethod()
-        return AWQMarlinLinearMethod(self)
-    elif isinstance(layer, FusedMoE):
-        return AWQMoEMethod(self)
-    return None
+def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
+    """
+    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
+    can recognize sglang layers
+    """

+    if reverse:
+        builtins.isinstance = original_isinstance
+        return

-original_awq_moe_method_apply = AWQMoEMethod.apply
-
-
-def awq_moe_method_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool = False,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    **kwargs,
-):
-    return original_awq_moe_method_apply(
-        self,
-        layer,
-        x,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        custom_routing_function,
-        scoring_func,
-        e_score_correction_bias,
-    )
-
-
-def patch_vllm_linear_base_isinstance():
-    import builtins
-
+    from vllm.model_executor.layers.fused_moe import FusedMoE
     from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )

     from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-
-    original_isinstance = builtins.isinstance
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )

     def patched_isinstance(obj, classinfo):
         if classinfo is LinearBase:
             return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
         return original_isinstance(obj, classinfo)

     builtins.isinstance = patched_isinstance


-def apply_monkey_patches():
+def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
+    """
+    Monkey patch the apply function of vllm's FusedMoEMethodBase.
+    Convert sglang arguments to vllm arguments.
+    """
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)
+
+
+def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod

     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
-    setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
+
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


-patch_vllm_linear_base_isinstance()
-# Apply patches when module is imported
-apply_monkey_patches()
+monkey_patch_quant_configs()


 __all__ = [
-    "QuantizationConfig",
     "get_quantization_config",
     "QUANTIZATION_METHODS",
 ]
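The per-config `awq_get_quant_method` override is gone; instead, sglang briefly patches `builtins.isinstance` so that vllm's stock `get_quant_method` implementations recognize sglang's `LinearBase`, `FusedMoE`, and `VocabParallelEmbedding` subclasses, and `monkey_patch_moe_apply` adapts sglang's extra MoE arguments to whatever `apply` signature the installed vllm exposes. The sketch below shows the intended patch/unpatch bracket around a vllm call; `resolve_quant_method` and its arguments are hypothetical stand-ins for the real call sites in sglang's layer code.

```python
# Sketch only: quant_config is a vllm QuantizationConfig, layer an sglang layer.
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer

def resolve_quant_method(quant_config, layer, prefix: str):
    # Make vllm's isinstance(layer, LinearBase / FusedMoE / VocabParallelEmbedding)
    # checks succeed for the corresponding sglang classes.
    monkey_patch_isinstance_for_vllm_base_layer()
    try:
        return quant_config.get_quant_method(layer, prefix)
    finally:
        # Restore the builtin even if get_quant_method raises.
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
```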
sglang/srt/layers/quantization/blockwise_int8.py
@@ -13,12 +13,11 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
 from sglang.srt.utils import set_weight_attrs