sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,13 @@ from sglang.srt.utils import (
     is_hip,
 )
 
+try:
+    from triton.tools.tensor_descriptor import TensorDescriptor
+
+    _support_tensor_descriptor = True
+except:
+    _support_tensor_descriptor = False
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -41,6 +48,10 @@ elif _is_hip:
     padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 
 
+def support_tensor_descriptor():
+    return _support_tensor_descriptor
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
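The two hunks above add the standard optional-dependency probe: `TensorDescriptor` (Triton's TMA descriptor API) only exists on newer Triton builds, so availability is recorded once at import time and exposed through `support_tensor_descriptor()`. A self-contained sketch of the pattern plus a hypothetical call site follows (only the probe and helper mirror the diff; the gating policy is illustrative):

```python
# Sketch only: the probe/helper mirror the diff; pick_tma_flags is hypothetical.
try:
    from triton.tools.tensor_descriptor import TensorDescriptor  # newer Triton only

    _support_tensor_descriptor = True
except Exception:
    _support_tensor_descriptor = False


def support_tensor_descriptor() -> bool:
    return _support_tensor_descriptor


def pick_tma_flags(use_fp8_w8a8: bool, block_quant: bool) -> dict:
    # Hypothetical policy: the kernel asserts that descriptor (TMA) loads are
    # only taken on the block-quantized fp8 path, and they require the
    # TensorDescriptor API to exist at all.
    enable = support_tensor_descriptor() and use_fp8_w8a8 and block_quant
    return {"a_use_tma": enable, "b_use_tma": enable}
```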
@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq(
     use_int4_w4a16: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
     even_Ks: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq(
     token_mask = offs_token < num_valid_tokens
 
     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
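`filter_expert` is a `tl.constexpr`: under expert parallelism the routing metadata can mark blocks with a `-1` expert id (expert not resident on this rank) and the kernel must write zeros for them, while a deployment that keeps every expert local can pass `filter_expert=False` and have the whole check compiled away. A minimal, self-contained illustration of that specialization pattern (not the sglang kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def gated_copy_kernel(
    x_ptr, out_ptr, expert_ids_ptr,
    BLOCK: tl.constexpr, filter_expert: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    expert_id = tl.load(expert_ids_ptr + pid)
    if filter_expert and expert_id == -1:
        # Block routed to an expert that is not on this rank: emit zeros.
        tl.store(out_ptr + offs, tl.zeros((BLOCK,), dtype=tl.float32))
        return
    tl.store(out_ptr + offs, tl.load(x_ptr + offs))


x = torch.randn(256, device="cuda")
out = torch.empty_like(x)
expert_ids = torch.tensor([0, -1, 2, 3], device="cuda", dtype=torch.int32)
# With filter_expert=False, Triton specializes the kernel and the sentinel
# check disappears from the compiled code entirely.
gated_copy_kernel[(4,)](x, out, expert_ids, BLOCK=64, filter_expert=True)
```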
@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq(
 def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
+    a_desc,
     b_ptr,
+    b_desc,
     bias_ptr,
     c_ptr,
     a_scale_ptr,

@@ -344,6 +358,8 @@ def fused_moe_kernel(
     use_int8_w8a16: tl.constexpr,
     per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
+    c_sorted: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -399,9 +415,10 @@ def fused_moe_kernel(
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    off_experts_i32 = tl.load(expert_ids_ptr + pid_m)
+    off_experts = off_experts_i32.to(tl.int64)
 
-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
@@ -421,15 +438,23 @@ def fused_moe_kernel(
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (
-        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-    )
+    if a_desc is not None:
+        assert use_fp8_w8a8 and group_n > 0 and group_k > 0
+        start_offs_m = pid_m * BLOCK_SIZE_M
+    else:
+        a_ptrs = a_ptr + (
+            offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+        )
+
+    if b_desc is not None:
+        start_offs_n = pid_n * BLOCK_SIZE_N
+    else:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+        )
 
-    b_ptrs = (
-        b_ptr
-        + off_experts * stride_be
-        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    )
     if bias_ptr is not None:
         bias = tl.load(
             bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
@@ -443,8 +468,14 @@ def fused_moe_kernel(
     if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
+            if a_desc is not None:
+                a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm
+            else:
+                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            if BLOCK_SIZE_N > group_n:
+                offs_bsn = offs_bn // group_n
+            else:
+                offs_bsn = pid_n * BLOCK_SIZE_N // group_n
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
@@ -469,37 +500,49 @@ def fused_moe_kernel(
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+    for k_start in range(0, K, BLOCK_SIZE_K):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        if even_Ks:
+        if a_desc is not None:
+            a = a_desc.load([start_offs_m, k_start])
+        elif even_Ks:
             a = tl.load(
                 a_ptrs,
                 mask=token_mask[:, None],
                 other=0.0,
            )
-            b = tl.load(b_ptrs)
         else:
             a = tl.load(
                 a_ptrs,
-                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k_start),
                 other=0.0,
             )
-            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        if b_desc is not None:
+            b = (
+                b_desc.load([off_experts_i32, start_offs_n, k_start])
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
+                .T
+            )
+        elif even_Ks:
+            b = tl.load(b_ptrs)
+        else:
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0)
 
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8 or use_int8_w8a8:
             if group_k > 0 and group_n > 0:
-                k_start = k * BLOCK_SIZE_K
                 offs_ks = k_start // group_k
                 a_scale = tl.load(
                     a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                 )
                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
-
-                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+                if BLOCK_SIZE_N > group_n:
+                    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+                else:
+                    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
             else:
                 if use_fp8_w8a8:
                     accumulator = tl.dot(a, b, acc=accumulator)
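The rewritten K loop steps `k_start` by `BLOCK_SIZE_K` directly; when a descriptor is present, `desc.load([...])` replaces the masked pointer loads, and the next hunk below drops the per-iteration pointer advance for that operand. A standalone sketch of descriptor-based loading in a toy kernel (illustrative only; assumes a Triton build with TMA descriptor support and a GPU that provides TMA):

```python
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def row_sum_kernel(a_desc, out_ptr, K, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for k_start in range(0, K, BLOCK_K):
        # Descriptor loads take block offsets, so there is no pointer
        # arithmetic, no K mask, and no per-iteration pointer advance.
        a = a_desc.load([pid_m * BLOCK_M, k_start])
        acc += tl.sum(a.to(tl.float32), axis=1)
    tl.store(out_ptr + pid_m * BLOCK_M + tl.arange(0, BLOCK_M), acc)


M, K, BLOCK_M, BLOCK_K = 256, 512, 64, 128
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
out = torch.empty(M, device="cuda", dtype=torch.float32)

# Descriptors need scratch space in global memory (the same allocator hook
# appears in the invoke_fused_moe_kernel change further down).
triton.set_allocator(
    lambda size, alignment, stream: torch.empty(size, device="cuda", dtype=torch.int8)
)

a_desc = TensorDescriptor(A, A.shape, A.stride(), [BLOCK_M, BLOCK_K])
row_sum_kernel[(M // BLOCK_M,)](a_desc, out, K, BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K)
```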
@@ -508,8 +551,10 @@ def fused_moe_kernel(
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        if a_desc is None:
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+        if b_desc is None:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if use_int8_w8a16:
         accumulator *= b_scale
@@ -528,7 +573,12 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    if c_sorted:
+        c_ptrs = (
+            c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
+        )
+    else:
+        c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel(
     per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    a_use_tma: bool = False,
+    b_use_tma: bool = False,
+    c_sorted: bool = False,
+    filter_expert: bool = True,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel(
             use_int4_w4a16=use_int4_w4a16,
             use_int8_w8a16=use_int8_w8a16,
             even_Ks=even_Ks,
+            filter_expert=filter_expert,
             **config,
         )
 
     else:
+        if a_use_tma or b_use_tma:
+            # TMA descriptors require a global memory allocation
+            def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+                return torch.empty(size, device="cuda", dtype=torch.int8)
+
+            triton.set_allocator(alloc_fn)
+        if a_use_tma:
+            a_desc = TensorDescriptor(
+                A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
+            )
+        else:
+            a_desc = None
+        if b_use_tma:
+            b_desc = TensorDescriptor(
+                B,
+                B.shape,
+                B.stride(),
+                [1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]],
+            )
+        else:
+            b_desc = None
 
         fused_moe_kernel[grid](
             A,
+            a_desc,
             B,
+            b_desc,
             bias,
             C,
             A_scale,
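On the host side the descriptors are ordinary Python objects built from the tensor plus the block shape the kernel will load, and Triton obtains descriptor scratch space through `triton.set_allocator`. A condensed sketch of the setup introduced above (shapes, dtypes, and the helper function are illustrative, not from the package):

```python
import torch
import triton
from triton.tools.tensor_descriptor import TensorDescriptor


def make_moe_tma_descriptors(A, B, block_m, block_n, block_k):
    # TMA descriptors require a global memory allocation; Triton requests it
    # through a user-supplied allocator (the same hook as in the diff above).
    def alloc_fn(size: int, alignment: int, stream):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    # Activations are loaded as [BLOCK_M, BLOCK_K] tiles; expert weights as a
    # one-expert [1, BLOCK_N, BLOCK_K] slice per descriptor load.
    a_desc = TensorDescriptor(A, A.shape, A.stride(), [block_m, block_k])
    b_desc = TensorDescriptor(B, B.shape, B.stride(), [1, block_n, block_k])
    return a_desc, b_desc


# Hypothetical shapes: (tokens, K) activations and (experts, N, K) weights.
A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
B = torch.randn(8, 256, 512, device="cuda", dtype=torch.bfloat16)
a_desc, b_desc = make_moe_tma_descriptors(A, B, block_m=64, block_n=128, block_k=128)
```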
@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel(
             B.stride(1),
             bias.stride(0) if bias is not None else 0,
             bias.stride(1) if bias is not None else 0,
-            C.stride(
-            C.stride(
+            C.stride(-2),
+            C.stride(-1),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
             A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
             B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
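Switching to `C.stride(-2)` / `C.stride(-1)` keeps the row and column strides passed to the kernel valid whether the output is the usual 3-D `(tokens, top_k, N)` buffer or a 2-D buffer in sorted-token order for the new `c_sorted` path. A tiny check with assumed shapes:

```python
import torch

c_unsorted = torch.empty(64, 8, 256)  # hypothetical (tokens, top_k, N) layout
c_sorted = torch.empty(512, 256)      # hypothetical (sorted tokens, N) layout

for c in (c_unsorted, c_sorted):
    stride_cm, stride_cn = c.stride(-2), c.stride(-1)
    assert stride_cn == 1 and stride_cm == c.shape[-1]
```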
@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel(
             use_int8_w8a16=use_int8_w8a16,
             per_channel_quant=per_channel_quant,
             even_Ks=even_Ks,
+            c_sorted=c_sorted,
+            filter_expert=filter_expert,
             **config,
         )
 
@@ -172,7 +172,7 @@ class FusedMoE(torch.nn.Module):
         self.reduce_results = reduce_results
         self.use_presharded_weights = use_presharded_weights
 
-        self.use_triton_kernels = get_moe_runner_backend().
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
 
         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()

@@ -232,7 +232,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method, ModelOptNvFp4FusedMoEMethod
         ) or (
             isinstance(self.quant_method, Fp8MoEMethod)
-            and self.quant_method.
+            and self.quant_method._should_use_cutlass_fused_experts()
         )
 
     def _load_per_tensor_weight_scale(

@@ -839,7 +839,7 @@ class FusedMoE(torch.nn.Module):
             dispatch_output=dispatch_output,
             **kwargs,
         )
-        final_hidden_states = self.dispatcher.combine(combine_input)
+        final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
 
         # TODO: should we add some conditions here?
         final_hidden_states = final_hidden_states[
@@ -47,7 +47,7 @@ def triton_kernel_moe_forward(
 
     from sglang.srt.layers.moe.topk import TopKOutputChecker
 
-    assert TopKOutputChecker.
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)
 
     routing_data, gather_idx, scatter_idx = topk_output
 

@@ -172,6 +172,7 @@ def triton_kernel_moe_with_bias_forward(
     b2: torch.Tensor,
     topk_output: TopKOutput,
     moe_runner_config: MoeRunnerConfig,
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,

@@ -184,7 +185,7 @@ def triton_kernel_moe_with_bias_forward(
 ) -> torch.Tensor:
     from sglang.srt.layers.moe.topk import TopKOutputChecker
 
-    assert TopKOutputChecker.
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)
 
     routing_data, gather_idx, scatter_idx = topk_output
 

@@ -201,6 +202,7 @@ def triton_kernel_moe_with_bias_forward(
         scatter_indx=scatter_idx,
         inplace=False, # triton kernel doesn't support inplace
         activation=moe_runner_config.activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         per_channel_quant=per_channel_quant,
         global_num_experts=global_num_experts,

@@ -228,6 +230,7 @@ def triton_kernel_fused_experts_with_bias(
     scatter_indx: ScatterIndx,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,

@@ -296,7 +299,7 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         gather_indx=gather_idx,
         precision_config=w1_pcg,
-        gammas=None,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
         fused_activation=act,
     )
 

@@ -307,5 +310,5 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         scatter_indx=scatter_idx,
         precision_config=w2_pcg,
-        gammas=routing_data.gate_scal,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
     )
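Taken together, the two `gammas=` changes move the per-token routing weights (`routing_data.gate_scal`) from the second matmul to the first matmul when `apply_router_weight_on_input` is set, so the weights are applied exactly once, on the expert input instead of the expert output. Condensed into a sketch (the helper is hypothetical; the names come from the diff):

```python
def pick_gammas(routing_data, apply_router_weight_on_input: bool):
    # Routing weights are applied exactly once: on the first matmul when the
    # flag is set, otherwise on the second matmul as before.
    up_gammas = routing_data.gate_scal if apply_router_weight_on_input else None
    down_gammas = None if apply_router_weight_on_input else routing_data.gate_scal
    return up_gammas, down_gammas
```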