sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
The code changes below are all from sglang/srt/layers/moe/ep_moe/kernels.py (+31 -452).

@@ -130,28 +130,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
 
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
-
+    expert_id_minus_1 = tl.program_id(0) - 1
     low = 0
     high = num_toks - 1
     target_location = -1
     while low <= high:
         mid = (low + high) // 2
 
-        if tl.load(reorder_topk_ids + mid) >
+        if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
             high = mid - 1
         else:
             low = mid + 1
             target_location = mid
-    tl.store(seg_indptr +
+    tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
 
 
-def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
+def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
 
-    seg_indptr = torch.zeros(
+    seg_indptr = torch.zeros(
+        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
+    )
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
-    compute_seg_indptr_triton_kernel[(
+    compute_seg_indptr_triton_kernel[(num_local_experts,)](
         reorder_topk_ids, seg_indptr, topk_ids.numel()
     )
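For orientation, here is a rough plain-PyTorch sketch (not part of the diff; the toy routing below is made up) of the layout these helpers build: after a stable sort of the flattened top-k expert ids, seg_indptr[e] .. seg_indptr[e+1] brackets the contiguous run of entries routed to expert e, which is what the per-expert binary search in compute_seg_indptr_triton_kernel computes.

```python
import torch

# Illustrative reference (plain PyTorch, not the Triton path) for the
# seg_indptr layout produced by run_moe_ep_preproess; toy sizes only.
num_local_experts = 4
topk_ids = torch.tensor([[2, 0], [1, 2], [3, 3]])  # 3 tokens, top-2 routing

reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_local_experts + 1, dtype=torch.int64)
seg_indptr[1:] = torch.cumsum(
    torch.bincount(reorder_topk_ids, minlength=num_local_experts), dim=0
)
print(reorder_topk_ids.tolist())  # [0, 1, 2, 2, 3, 3]
print(seg_indptr.tolist())        # [0, 1, 2, 4, 6]
```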
@@ -164,25 +166,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     return reorder_topk_ids, src2dst, seg_indptr
 
 
-def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
-    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
-
-    seg_indptr = torch.zeros(
-        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
-    )
-    src2dst = torch.empty(
-        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
-    )
-
-    BLOCK_SIZE = 512
-    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
-    compute_src2dst_triton_kernel[grid](
-        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
-    )
-
-    return reorder_topk_ids, src2dst, seg_indptr
-
-
 @triton.jit
 def pre_reorder_triton_kernel_for_cutlass_moe(
     input_ptr,
@@ -190,52 +173,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
     src2dst_ptr,
     topk_ids_ptr,
     a1_scales_ptr,
-
+    num_local_experts,
     topk,
     hidden_size,
     BLOCK_SIZE: tl.constexpr,
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
-    src2dst_ptr = src2dst_ptr + src_idx * topk
-    topk_ids_ptr = topk_ids_ptr + src_idx * topk
-
-    src_ptr = input_ptr + src_idx * hidden_size
-    for idx in range(topk):
-        expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id != num_experts:
-            if a1_scales_ptr is not None:
-                scale = 1.0 / tl.load(a1_scales_ptr)
-            else:
-                scale = 1.0
-
-            dst_idx = tl.load(src2dst_ptr + idx)
-            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
-            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-                offset = start_offset + tl.arange(0, BLOCK_SIZE)
-                mask = offset < hidden_size
-                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
-                out_data = (in_data * scale).to(OutDtype)
-                tl.store(dst_ptr + offset, out_data, mask=mask)
-
-
-@triton.jit
-def pre_reorder_triton_kernel(
-    input_ptr,
-    gateup_input_ptr,
-    src2dst_ptr,
-    topk_ids_ptr,
-    a1_scales_ptr,
-    start_expert_id,
-    end_expert_id,
-    topk,
-    hidden_size,
-    BLOCK_SIZE: tl.constexpr,
-    use_per_token_if_dynamic: tl.constexpr,
-):
-    OutDtype = gateup_input_ptr.dtype.element_ty
-
     src_idx_int32 = tl.program_id(0)
     src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
@@ -244,15 +188,11 @@ def pre_reorder_triton_kernel(
 
     vec = tl.arange(0, BLOCK_SIZE)
 
-    if a1_scales_ptr is not None and use_per_token_if_dynamic:
-        scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
-
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id
+        if expert_id != num_local_experts:
             if a1_scales_ptr is not None:
-
-                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+                scale = 1.0 / tl.load(a1_scales_ptr)
             else:
                 scale = 1.0
 
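The `expert_id != num_local_experts` guard follows the convention that entries routed to a non-local expert carry the sentinel value num_local_experts and must be skipped. A rough dense-Python picture of the pre-reorder scatter, ignoring the quantization scale (the tensors and sizes below are illustrative, not from the diff):

```python
import torch

# Illustrative scatter equivalent to the pre_reorder kernel: copy each
# token's hidden state into the expert-grouped buffer at the slot src2dst
# assigns, skipping sentinel entries.
num_local_experts, topk, hidden = 2, 2, 8
x = torch.randn(3, hidden)
topk_ids = torch.tensor([[0, 1], [1, 2], [0, 0]])  # 2 == sentinel (non-local)
src2dst = torch.tensor([[0, 3], [4, 0], [1, 2]])   # slot per (token, k) pair

gateup_input = torch.zeros(5, hidden)
for t in range(x.size(0)):
    for k in range(topk):
        if topk_ids[t, k] != num_local_experts:
            gateup_input[src2dst[t, k]] = x[t]
```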
@@ -267,52 +207,6 @@ def pre_reorder_triton_kernel(
                 tl.store(dst_ptr + offset, out_data, mask=mask)
 
 
-@triton.jit
-def silu_and_mul_triton_kernel(
-    gateup_output,
-    down_input,
-    hidden_size,
-    reorder_topk_ids,
-    scales,
-    start_expert_id,
-    end_expert_id,
-    BLOCK_SIZE: tl.constexpr,
-):
-    InDtype = gateup_output.dtype.element_ty
-    OutDtype = down_input.dtype.element_ty
-
-    half_hidden_size = hidden_size // 2
-
-    pid = tl.program_id(0)
-    expert_id = tl.load(reorder_topk_ids + pid)
-    if expert_id >= start_expert_id and expert_id <= end_expert_id:
-        gateup_output_ptr = gateup_output + pid * hidden_size
-        gate_output_ptr = gateup_output_ptr
-        up_output_ptr = gateup_output_ptr + half_hidden_size
-        down_input_ptr = down_input + pid * half_hidden_size
-
-        if scales is not None:
-            scale = tl.load(scales + expert_id - start_expert_id)
-            scale = (1 / scale).to(InDtype)
-        else:
-            scale = 1
-
-        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
-            offset = start_offset + tl.arange(0, BLOCK_SIZE)
-            mask = offset < half_hidden_size
-
-            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
-            up_output = tl.load(up_output_ptr + offset, mask=mask)
-
-            # silu & mul & quantize
-            gate_output = gate_output * tl.sigmoid(gate_output)
-            gate_output = gate_output.to(InDtype)
-
-            silu_mul_output = gate_output * up_output * scale
-            silu_mul_output = silu_mul_output.to(OutDtype)
-            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
-
-
 # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
 @triton.jit
 def _silu_and_mul_post_quant_kernel(
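The removed silu_and_mul_triton_kernel fused the gated SiLU activation with a per-expert dequantization scale; the post-quant variant kept above presumably covers the remaining callers. Ignoring the scale, its math reduces to the usual gated activation, sketched here in plain PyTorch (tensor sizes illustrative):

```python
import torch
import torch.nn.functional as F

# Reference math of the removed kernel, minus the per-expert dequant scale:
# split the gate_up GEMM output in half, SiLU the gate half, multiply by up.
gateup_output = torch.randn(16, 2048)      # [tokens, hidden_size]
gate, up = gateup_output.chunk(2, dim=-1)  # halves of the last dimension
down_input = F.silu(gate) * up             # [tokens, hidden_size // 2]
```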
@@ -461,84 +355,15 @@ def silu_and_mul_masked_post_quant_fwd(
 
 
 @triton.jit
-def tanh(x):
-    return 2 * tl.sigmoid(2 * x) - 1
-
-
-@triton.jit
-def gelu_and_mul_triton_kernel(
-    gateup_output,
-    down_input,
-    hidden_size,
-    reorder_topk_ids,
-    scales,
-    start_expert_id,
-    end_expert_id,
-    BLOCK_SIZE: tl.constexpr,
-):
-    InDtype = gateup_output.dtype.element_ty
-    OutDtype = down_input.dtype.element_ty
-
-    half_hidden_size = hidden_size // 2
-
-    pid = tl.program_id(0)
-    expert_id = tl.load(reorder_topk_ids + pid)
-    if expert_id >= start_expert_id and expert_id <= end_expert_id:
-        gateup_output_ptr = gateup_output + pid * hidden_size
-        gate_output_ptr = gateup_output_ptr
-        up_output_ptr = gateup_output_ptr + half_hidden_size
-        down_input_ptr = down_input + pid * half_hidden_size
-
-        if scales is not None:
-            scale = tl.load(scales + expert_id - start_expert_id)
-            scale = (1 / scale).to(InDtype)
-        else:
-            scale = 1
-
-        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
-            offset = start_offset + tl.arange(0, BLOCK_SIZE)
-            mask = offset < half_hidden_size
-
-            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
-            up_output = tl.load(up_output_ptr + offset, mask=mask)
-
-            # gelu & mul & quantize
-            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
-            # sqrt(2/pi)
-            kAlpha = 0.7978845608028654
-            gate_output = (
-                0.5
-                * gate_output
-                * (
-                    1
-                    + tanh(
-                        kAlpha
-                        * (
-                            gate_output
-                            + 0.044715 * gate_output * gate_output * gate_output
-                        )
-                    )
-                )
-            )
-            gate_output = gate_output.to(InDtype)
-
-            gelu_mul_output = gate_output * up_output * scale
-            gelu_mul_output = gelu_mul_output.to(OutDtype)
-            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
-
-
-@triton.jit
-def post_reorder_triton_kernel(
+def post_reorder_triton_kernel_for_cutlass_moe(
     down_output_ptr,
     output_ptr,
     src2dst_ptr,
     topk_ids_ptr,
     topk_weights_ptr,
-    start_expert_id,
-    end_expert_id,
     topk,
+    num_local_experts,
     hidden_size,
-    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
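The removed tanh and gelu_and_mul_triton_kernel pair implemented the tanh approximation of GELU, using the identity tanh(x) = 2·sigmoid(2x) − 1. A quick PyTorch check of that math against the built-in approximation (illustrative, not from the diff):

```python
import math
import torch
import torch.nn.functional as F

# Verify the removed kernel's GELU math, with kAlpha = sqrt(2/pi)
# (0.7978845608028654) and tanh(x) expressed as 2*sigmoid(2*x) - 1.
x = torch.randn(1024)
k_alpha = math.sqrt(2.0 / math.pi)
inner = k_alpha * (x + 0.044715 * x.pow(3))
tanh_via_sigmoid = 2.0 * torch.sigmoid(2.0 * inner) - 1.0
gelu_manual = 0.5 * x * (1.0 + tanh_via_sigmoid)
assert torch.allclose(gelu_manual, F.gelu(x, approximate="tanh"), atol=1e-5)
```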
@@ -549,7 +374,6 @@ def post_reorder_triton_kernel(
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
 
-    computed = False
     store_ptr = output_ptr + src_idx * hidden_size
 
     vec = tl.arange(0, BLOCK_SIZE)
@@ -561,37 +385,25 @@ def post_reorder_triton_kernel(
         sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
         for idx in range(topk):
             expert_id = tl.load(topk_ids_ptr + idx)
-            if expert_id
-                computed = True
+            if expert_id != num_local_experts:
                 dst_idx_int32 = tl.load(src2dst_ptr + idx)
                 dst_idx = dst_idx_int32.to(tl.int64)
-                dst_idx = dst_idx - dst_start
                 weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                 load_ptr = down_output_ptr + dst_idx * hidden_size
                 in_data = tl.load(load_ptr + offset, mask=mask)
                 sum_vec += in_data * weigh_scale
         tl.store(store_ptr + offset, sum_vec, mask=mask)
 
-    if computed == False:
-        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-            offset = start_offset + vec
-            mask = offset < hidden_size
-            tl.store(
-                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
-            )
-
 
 @triton.jit
-def post_reorder_triton_kernel_for_cutlass_moe(
+def post_reorder_triton_kernel(
     down_output_ptr,
     output_ptr,
     src2dst_ptr,
     topk_ids_ptr,
     topk_weights_ptr,
-    num_experts,
     topk,
     hidden_size,
-    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
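In rough dense form, the combine in the cutlass variant above is a top-k-weight-weighted sum of the expert outputs gathered through src2dst; because sum_vec starts at zero and is stored unconditionally, tokens whose entries were all dropped now get zeros without the removed `computed` bookkeeping. A sketch with made-up tensors:

```python
import torch

# Illustrative combine matching the post_reorder math; sentinel entries
# (expert id == num_local_experts) contribute nothing.
topk, hidden, num_local_experts = 2, 8, 2
down_output = torch.randn(5, hidden)           # expert-grouped rows
topk_ids = torch.tensor([[0, 1], [1, 2]])      # 2 == sentinel
topk_weights = torch.tensor([[0.7, 0.3], [1.0, 0.0]])
src2dst = torch.tensor([[0, 3], [4, 0]])

output = torch.zeros(2, hidden)
for t in range(2):
    for k in range(topk):
        if topk_ids[t, k] != num_local_experts:
            output[t] += topk_weights[t, k] * down_output[src2dst[t, k]]
```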
@@ -613,10 +425,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
         sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
         for idx in range(topk):
             expert_id = tl.load(topk_ids_ptr + idx)
-            if expert_id
+            if expert_id >= 0:
                 dst_idx_int32 = tl.load(src2dst_ptr + idx)
                 dst_idx = dst_idx_int32.to(tl.int64)
-                dst_idx = dst_idx - dst_start
                 weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                 load_ptr = down_output_ptr + dst_idx * hidden_size
                 in_data = tl.load(load_ptr + offset, mask=mask)
@@ -624,232 +435,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
     tl.store(store_ptr + offset, sum_vec, mask=mask)
 
 
-@triton.jit
-def compute_m_range(
-    pid,
-    batch_size,
-    seg_indptr,
-    weight_indices,
-    m_num_tiles_indptr,
-    BLOCK_SIZE_M: tl.constexpr,
-):
-    idx = 0
-    for bs in range(batch_size):
-        tiles = tl.load(m_num_tiles_indptr + bs)
-        if pid >= tiles:
-            idx = bs
-
-    idx_start = tl.load(m_num_tiles_indptr + idx)
-
-    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
-    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
-    expert_id = tl.load(weight_indices + idx)
-    return m_range_start, m_range_end, expert_id
-
-
-@triton.jit
-def grouped_gemm_triton_kernel(
-    a,
-    b,
-    c,
-    batch_size,
-    N,
-    K,
-    seg_indptr,
-    weight_indices,
-    m_num_tiles_indptr,
-    scale_a,
-    scale_b,
-    use_fp8_w8a8: tl.constexpr,
-    group_n: tl.constexpr,
-    group_k: tl.constexpr,
-    a_stride_0: tl.constexpr,
-    b_stride_0: tl.constexpr,
-    b_stride_1: tl.constexpr,
-    as_stride_0: tl.constexpr,
-    as_stride_1: tl.constexpr,
-    bs_stride_0: tl.constexpr,
-    bs_stride_2: tl.constexpr,
-    bs_stride_1: tl.constexpr,
-    use_per_token_if_dynamic: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-):
-    c_dtype = c.dtype.element_ty
-
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
-    if pid_m >= total_m_block:
-        return
-
-    m_range_start, m_range_end, expert_id = compute_m_range(
-        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
-    )
-    if m_range_end - m_range_start == 0:
-        return
-
-    n_range_start = pid_n * BLOCK_SIZE_N
-    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
-
-    offs_am = tl.arange(0, BLOCK_SIZE_M)
-    offs_bn = tl.arange(0, BLOCK_SIZE_N)
-
-    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
-    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
-    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
-    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-
-    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
-    b_ptr = b + (
-        (expert_id * b_stride_0)
-        + (n_range_start + offs_bn[:, None]) * b_stride_1
-        + offs_k[None, :]
-    )
-
-    if group_k > 0 and group_n > 0:
-        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
-        offs_bsn = (n_range_start + offs_bn) // group_n
-        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a_tile = tl.load(
-            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
-        )
-        b_tile = tl.load(
-            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
-        )
-
-        if group_k > 0 and group_n > 0:
-            k_start = k * BLOCK_SIZE_K
-            offs_ks = k_start // group_k
-            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
-            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
-            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
-        else:
-            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
-        a_ptr += BLOCK_SIZE_K
-        b_ptr += BLOCK_SIZE_K
-
-    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        if use_per_token_if_dynamic:
-            scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
-        else:
-            scale_a_value = tl.load(scale_a + expert_id)
-        scale_b_value = tl.load(scale_b + expert_id)
-        accumulator *= scale_a_value * scale_b_value
-
-    c_tile = accumulator.to(c_dtype)
-
-    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
-    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
-    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
-    tl.store(c_ptr, c_tile, mask=c_mask)
-
-
-@triton.jit
-def compute_m_num_tiles_indptr(
-    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
-):
-    for bs in range(batch_size):
-        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
-        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
-        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
-        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
-
-
-def grouped_gemm_triton(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    c: torch.Tensor,
-    batch_size: int,
-    weight_column_major: bool,
-    seg_indptr: Optional[torch.Tensor] = None,
-    weight_indices: Optional[torch.Tensor] = None,
-    use_fp8_w8a8: bool = False,
-    scale_a: torch.Tensor = None,
-    scale_b: torch.Tensor = None,
-    block_shape: Optional[List[int]] = None,
-    c_dtype=None,
-    use_per_token_if_dynamic: bool = True,
-):
-    assert weight_column_major == True  # TODO: more
-    if use_fp8_w8a8 and block_shape is None:
-        assert scale_a is not None and scale_b is not None
-
-    if block_shape is not None:
-        a_original = a
-
-        assert len(block_shape) == 2
-        block_n, block_k = block_shape[0], block_shape[1]
-        a, scale_a = per_token_group_quant_fp8(a, block_k)
-
-        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
-        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
-        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
-
-        dispose_tensor(a_original)
-
-    # TODO: adjust config or tune kernel
-    # Reduce block size to prevent L40 shared memory overflow.
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 32,
-        "BLOCK_SIZE_K": 128,
-    }
-
-    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
-    compute_m_num_tiles_indptr[(1,)](
-        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
-    )
-
-    if c is None:
-        assert c_dtype is not None
-        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
-
-    grid = lambda META: (
-        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
-        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
-    )
-
-    if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
-        assert (
-            scale_a.shape[0] == a.shape[0]
-        ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
-
-    grouped_gemm_triton_kernel[grid](
-        a,
-        b,
-        c,
-        batch_size,
-        b.size(1),
-        b.size(2),
-        seg_indptr,
-        weight_indices,
-        m_num_tiles_indptr,
-        scale_a,
-        scale_b,
-        use_fp8_w8a8,
-        0 if block_shape is None else block_shape[0],
-        0 if block_shape is None else block_shape[1],
-        a.stride(0),
-        b.stride(0),
-        b.stride(1),
-        scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
-        scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
-        scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
-        scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
-        scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
-        use_per_token_if_dynamic,
-        **config,
-    )
-    return c
-
-
 @triton.jit
 def _fwd_kernel_ep_scatter_1(
     num_recv_tokens_per_expert,
@@ -1234,7 +819,7 @@ def deepgemm_compute_src2dst_triton_kernel(
     mask = dst_id < num_toks
     src_id = tl.load(reorder_ids + dst_id, mask=mask)
     expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
-    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
     expert_dst_offset = dst_id - expert_dst_start
     dst_id = expert_id * m_max + expert_dst_offset
     tl.store(src2dst + src_id, dst_id, mask=mask)
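The new mask guards the gather in the deepgemm path, where unrouted slots presumably carry a negative expert id (-1): without it, `seg_indptr + expert_id` would be read at an out-of-bounds index. A toy PyTorch picture of the masked-load intent (values made up, not from the diff):

```python
import torch

# Toy picture of tl.load(seg_indptr + expert_id, mask=(expert_id >= 0)):
# gather only where the expert id is valid; masked lanes keep a dummy
# value whose result is never consumed downstream.
seg_indptr = torch.tensor([0, 2, 5, 9])
expert_id = torch.tensor([1, -1, 2])  # -1: slot not routed to this rank
start = torch.zeros_like(expert_id)
valid = expert_id >= 0
start[valid] = seg_indptr[expert_id[valid]]
print(start.tolist())  # [2, 0, 5]
```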
@@ -1248,10 +833,7 @@ def fill_gateup_input_triton_kernel(
     gateup_input_scale_ptr,
     src2dst_ptr,
     topk_ids_ptr,
-    start_expert_id,
-    end_expert_id,
     topk,
-    m_max,
     hidden_size,
     scale_size,
     BLOCK_SIZE: tl.constexpr,
@@ -1267,10 +849,9 @@ def fill_gateup_input_triton_kernel(
     vec = tl.arange(0, BLOCK_SIZE)
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id >=
+        if expert_id >= 0:
             dst_idx_int32 = tl.load(src2dst_ptr + idx)
             dst_idx = dst_idx_int32.to(tl.int64)
-            dst_idx = dst_idx - start_expert_id * m_max
             dst_ptr = gateup_input_ptr + dst_idx * hidden_size
             for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                 offset = start_offset + vec
@@ -1287,31 +868,31 @@ def fill_gateup_input_triton_kernel(
 
 def moe_ep_deepgemm_preprocess(
     topk_ids: torch.Tensor,
-
+    num_local_experts: int,
     hidden_states: torch.Tensor,
     top_k: int,
-    start_expert_id,
-    end_expert_id,
     block_shape,
     output_dtype: torch.dtype = torch.float8_e4m3fn,
 ):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
-    seg_indptr = torch.zeros(
+    seg_indptr = torch.zeros(
+        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
+    )
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
-    masked_m = torch.
+    masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
 
-    compute_seg_indptr_triton_kernel[(
+    compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
         reorder_topk_ids, seg_indptr, topk_ids.numel()
     )
 
     grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
-    compute_masked_m_triton_kernel[(
+    compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
 
     # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
-    m_max = (hidden_states.size(0) +
-    expected_m = (topk_ids.numel()
+    m_max = (hidden_states.size(0) // 256 + 1) * 256
+    expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
     gateup_input = torch.empty(
-        (
+        (num_local_experts, m_max, hidden_states.size(1)),
         device=hidden_states.device,
         dtype=output_dtype,
     )
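The sizing above rounds the per-expert M dimension up to a multiple of 256, as DeepGEMM's masked grouped GEMM requires, and estimates the expected entries per expert with a ceiling division. A worked example (the concrete values are made up):

```python
# Worked example of the padding arithmetic in moe_ep_deepgemm_preprocess.
num_tokens, top_k, num_local_experts = 1000, 8, 16

m_max = (num_tokens // 256 + 1) * 256
# 1000 // 256 == 3, so m_max == 1024; note this rounds up even when
# num_tokens is already a multiple of 256 (512 -> 768).

expected_m = (num_tokens * top_k - 1) // num_local_experts + 1
# ceil(8000 / 16) == 500 routed entries per local expert on average.
```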
@@ -1330,6 +911,8 @@ def moe_ep_deepgemm_preprocess(
         block_shape = [128, 128]
     assert len(block_shape) == 2
     block_n, block_k = block_shape[0], block_shape[1]
+
+    # TODO: fuse this with the preprocess
     hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
 
     gateup_input_scale = torch.empty(
|
|
1345
928
|
gateup_input_scale,
|
1346
929
|
src2dst,
|
1347
930
|
topk_ids,
|
1348
|
-
start_expert_id,
|
1349
|
-
end_expert_id,
|
1350
931
|
top_k,
|
1351
|
-
m_max,
|
1352
932
|
hidden_states.size(1),
|
1353
933
|
scale.size(1),
|
1354
934
|
BLOCK_SIZE=1024,
|
1355
935
|
)
|
1356
936
|
|
1357
937
|
return (
|
1358
|
-
|
1359
|
-
masked_m[start_expert_id : (end_expert_id + 1)],
|
938
|
+
masked_m,
|
1360
939
|
expected_m,
|
1361
940
|
src2dst,
|
1362
941
|
gateup_input,
|