sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -5,13 +5,15 @@ from dataclasses import dataclass
|
|
5
5
|
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
6
6
|
|
7
7
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
8
|
-
from sglang.srt.layers.moe import
|
9
|
-
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
8
|
+
from sglang.srt.layers.moe.token_dispatcher.base import (
|
10
9
|
BaseDispatcher,
|
11
10
|
BaseDispatcherConfig,
|
11
|
+
CombineInput,
|
12
|
+
CombineInputFormat,
|
12
13
|
DispatchOutput,
|
13
14
|
DispatchOutputFormat,
|
14
15
|
)
|
16
|
+
from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
|
15
17
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
16
18
|
from sglang.srt.utils import (
|
17
19
|
get_bool_env_var,
|
@@ -40,11 +42,6 @@ from enum import Enum, IntEnum, auto
|
|
40
42
|
import torch
|
41
43
|
import torch.distributed as dist
|
42
44
|
|
43
|
-
from sglang.srt.layers.moe.ep_moe.kernels import (
|
44
|
-
deepep_permute_triton_kernel,
|
45
|
-
deepep_post_reorder_triton_kernel,
|
46
|
-
deepep_run_moe_deep_preprocess,
|
47
|
-
)
|
48
45
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
49
46
|
|
50
47
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
@@ -56,6 +53,7 @@ class DeepEPNormalOutput(NamedTuple):
|
|
56
53
|
"""DeepEP normal dispatch output."""
|
57
54
|
|
58
55
|
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
|
56
|
+
# hidden_states_scale
|
59
57
|
topk_idx: torch.Tensor
|
60
58
|
topk_weights: torch.Tensor
|
61
59
|
num_recv_tokens_per_expert: List[int]
|
@@ -79,24 +77,32 @@ class DeepEPLLOutput(NamedTuple):
|
|
79
77
|
return DispatchOutputFormat.DEEPEP_LL
|
80
78
|
|
81
79
|
|
82
|
-
|
83
|
-
|
80
|
+
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
81
|
+
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
84
82
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
expected_m: int
|
83
|
+
|
84
|
+
class DeepEPNormalCombineInput(NamedTuple):
|
85
|
+
"""DeepEP normal combine input."""
|
86
|
+
|
87
|
+
pass
|
91
88
|
|
92
89
|
@property
|
93
|
-
def format(self) ->
|
94
|
-
return
|
90
|
+
def format(self) -> CombineInputFormat:
|
91
|
+
return CombineInputFormat.DEEPEP_NORMAL
|
95
92
|
|
96
93
|
|
97
|
-
|
98
|
-
|
99
|
-
|
94
|
+
class DeepEPLLCombineInput(NamedTuple):
|
95
|
+
"""DeepEP low latency combine input."""
|
96
|
+
|
97
|
+
pass
|
98
|
+
|
99
|
+
@property
|
100
|
+
def format(self) -> CombineInputFormat:
|
101
|
+
return CombineInputFormat.DEEPEP_LL
|
102
|
+
|
103
|
+
|
104
|
+
assert isinstance(DeepEPNormalCombineInput, CombineInput)
|
105
|
+
assert isinstance(DeepEPLLCombineInput, CombineInput)
|
100
106
|
|
101
107
|
|
102
108
|
class DeepEPDispatchMode(IntEnum):
|
@@ -272,6 +278,9 @@ class _DeepEPDispatcherImplBase:
|
|
272
278
|
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
|
273
279
|
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
|
274
280
|
)
|
281
|
+
# DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
|
282
|
+
# and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
|
283
|
+
assert self.num_max_dispatch_tokens_per_rank <= 1024
|
275
284
|
|
276
285
|
self.handle = None
|
277
286
|
|
@@ -409,7 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
409
418
|
topk_idx: torch.Tensor,
|
410
419
|
topk_weights: torch.Tensor,
|
411
420
|
):
|
412
|
-
|
421
|
+
from sglang.srt.layers.moe.ep_moe.kernels import (
|
422
|
+
deepep_post_reorder_triton_kernel,
|
423
|
+
)
|
424
|
+
|
425
|
+
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
413
426
|
output = hidden_states
|
414
427
|
else:
|
415
428
|
if hidden_states.shape[0] > 0:
|
@@ -523,23 +536,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
523
536
|
masked_m
|
524
537
|
)
|
525
538
|
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
expected_m,
|
534
|
-
)
|
535
|
-
else:
|
536
|
-
deepep_output = DeepEPLLOutput(
|
537
|
-
hidden_states,
|
538
|
-
topk_idx,
|
539
|
-
topk_weights,
|
540
|
-
masked_m,
|
541
|
-
expected_m,
|
542
|
-
)
|
539
|
+
deepep_output = DeepEPLLOutput(
|
540
|
+
hidden_states,
|
541
|
+
topk_idx,
|
542
|
+
topk_weights,
|
543
|
+
masked_m,
|
544
|
+
expected_m,
|
545
|
+
)
|
543
546
|
return deepep_output
|
544
547
|
|
545
548
|
def _dispatch_core(
|
@@ -1,19 +1,61 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import NamedTuple
|
3
|
+
from typing import TYPE_CHECKING, NamedTuple
|
4
4
|
|
5
|
-
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from sglang.srt.layers.moe.token_dispatcher.base import (
|
8
|
+
BaseDispatcher,
|
9
|
+
CombineInput,
|
10
|
+
CombineInputFormat,
|
6
11
|
DispatchOutput,
|
7
12
|
DispatchOutputFormat,
|
8
13
|
)
|
9
14
|
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from sglang.srt.layers.moe.topk import TopKOutput
|
17
|
+
|
10
18
|
|
11
19
|
class StandardDispatchOutput(NamedTuple):
|
12
20
|
"""Standard dispatch output."""
|
13
21
|
|
22
|
+
hidden_states: torch.Tensor
|
23
|
+
topk_output: TopKOutput
|
24
|
+
|
14
25
|
@property
|
15
26
|
def format(self) -> DispatchOutputFormat:
|
16
27
|
return DispatchOutputFormat.STANDARD
|
17
28
|
|
18
29
|
|
19
30
|
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
31
|
+
|
32
|
+
|
33
|
+
class StandardCombineInput(NamedTuple):
|
34
|
+
"""Standard combine input."""
|
35
|
+
|
36
|
+
hidden_states: torch.Tensor
|
37
|
+
|
38
|
+
@property
|
39
|
+
def format(self) -> CombineInputFormat:
|
40
|
+
return CombineInputFormat.STANDARD
|
41
|
+
|
42
|
+
|
43
|
+
assert isinstance(StandardCombineInput, CombineInput)
|
44
|
+
|
45
|
+
|
46
|
+
class StandardDispatcher(BaseDispatcher):
|
47
|
+
|
48
|
+
def dispatch(
|
49
|
+
self, hidden_states: torch.Tensor, topk_output: TopKOutput
|
50
|
+
) -> DispatchOutput:
|
51
|
+
return StandardDispatchOutput(
|
52
|
+
hidden_states=hidden_states, topk_output=topk_output
|
53
|
+
)
|
54
|
+
|
55
|
+
def combine(self, combine_input: CombineInput) -> torch.Tensor:
|
56
|
+
if isinstance(combine_input, StandardCombineInput):
|
57
|
+
return combine_input.hidden_states
|
58
|
+
else:
|
59
|
+
# TODO: this branch should be removed in the future
|
60
|
+
assert isinstance(combine_input, torch.Tensor)
|
61
|
+
return combine_input
|
sglang/srt/layers/moe/topk.py
CHANGED
@@ -304,12 +304,12 @@ class TopK(CustomOp):
|
|
304
304
|
global_num_experts = router_logits.shape[-1]
|
305
305
|
|
306
306
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
307
|
-
if global_num_experts == 256
|
307
|
+
if global_num_experts == 256:
|
308
308
|
|
309
309
|
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
310
310
|
router_logits = router_logits.to(torch.float32)
|
311
311
|
|
312
|
-
|
312
|
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
313
313
|
router_logits,
|
314
314
|
k=self.topk_config.top_k,
|
315
315
|
bias=self.topk_config.correction_bias.to(torch.float32),
|
@@ -321,6 +321,24 @@ class TopK(CustomOp):
|
|
321
321
|
routed_scaling_factor=routed_scaling_factor,
|
322
322
|
eps=float(1e-20),
|
323
323
|
)
|
324
|
+
|
325
|
+
if self.topk_config.renormalize:
|
326
|
+
topk_weights_sum = (
|
327
|
+
topk_weights.sum(dim=-1, keepdim=True)
|
328
|
+
if self.topk_config.num_fused_shared_experts == 0
|
329
|
+
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
330
|
+
)
|
331
|
+
topk_weights = topk_weights / topk_weights_sum
|
332
|
+
|
333
|
+
if expert_location_dispatch_info is not None:
|
334
|
+
topk_ids = topk_ids_logical_to_physical(
|
335
|
+
topk_ids, expert_location_dispatch_info
|
336
|
+
)
|
337
|
+
get_global_expert_distribution_recorder().on_select_experts(
|
338
|
+
topk_ids=topk_ids
|
339
|
+
)
|
340
|
+
|
341
|
+
return StandardTopKOutput(topk_weights, topk_ids, _)
|
324
342
|
else:
|
325
343
|
self.topk_config.torch_native = True
|
326
344
|
return select_experts(
|
@@ -347,17 +365,28 @@ def fused_topk_torch_native(
|
|
347
365
|
gating_output: torch.Tensor,
|
348
366
|
topk: int,
|
349
367
|
renormalize: bool,
|
368
|
+
correction_bias: torch.Tensor = None,
|
350
369
|
):
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
370
|
+
if correction_bias is not None:
|
371
|
+
n_routed_experts = gating_output.shape[-1]
|
372
|
+
scores = gating_output.softmax(dim=-1)
|
373
|
+
scores_for_choice = scores.view(
|
374
|
+
-1, n_routed_experts
|
375
|
+
) + correction_bias.unsqueeze(0)
|
376
|
+
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
|
377
|
+
topk_weights = scores.gather(1, topk_ids)
|
378
|
+
else:
|
379
|
+
assert (
|
380
|
+
hidden_states.shape[0] == gating_output.shape[0]
|
381
|
+
), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
|
382
|
+
M, _ = hidden_states.shape
|
383
|
+
topk_weights = torch.empty(
|
384
|
+
M, topk, dtype=torch.float32, device=hidden_states.device
|
385
|
+
)
|
386
|
+
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
387
|
+
topk_weights = F.softmax(gating_output.float(), dim=-1)
|
388
|
+
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
389
|
+
|
361
390
|
if renormalize:
|
362
391
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
363
392
|
return topk_weights, topk_ids
|
@@ -370,6 +399,7 @@ def fused_topk_cpu(
|
|
370
399
|
renormalize: bool,
|
371
400
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
372
401
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
402
|
+
correction_bias: torch.Tensor = None,
|
373
403
|
):
|
374
404
|
topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
|
375
405
|
hidden_states=hidden_states,
|
@@ -815,6 +845,7 @@ def select_experts(
|
|
815
845
|
gating_output=router_logits,
|
816
846
|
topk=top_k,
|
817
847
|
renormalize=renormalize,
|
848
|
+
correction_bias=correction_bias,
|
818
849
|
)
|
819
850
|
elif custom_routing_function is None:
|
820
851
|
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
sglang/srt/layers/moe/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib.util
|
4
|
+
import logging
|
4
5
|
from enum import Enum
|
5
6
|
from functools import lru_cache
|
6
7
|
from typing import TYPE_CHECKING, Optional
|
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
|
|
12
13
|
get_attention_dp_size,
|
13
14
|
is_dp_attention_enabled,
|
14
15
|
)
|
15
|
-
from sglang.srt.utils import logger
|
16
16
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
from sglang.srt.server_args import ServerArgs
|
19
19
|
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
20
22
|
|
21
23
|
class MoeA2ABackend(Enum):
|
22
24
|
|
@@ -131,7 +133,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
|
|
131
133
|
global MOE_A2A_BACKEND
|
132
134
|
if MOE_A2A_BACKEND is None:
|
133
135
|
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
|
134
|
-
MOE_A2A_BACKEND = MoeA2ABackend
|
136
|
+
MOE_A2A_BACKEND = MoeA2ABackend.NONE
|
135
137
|
return MOE_A2A_BACKEND
|
136
138
|
|
137
139
|
|
@@ -139,7 +141,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
|
|
139
141
|
global MOE_RUNNER_BACKEND
|
140
142
|
if MOE_RUNNER_BACKEND is None:
|
141
143
|
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
|
142
|
-
MOE_RUNNER_BACKEND = MoeRunnerBackend
|
144
|
+
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
|
143
145
|
return MOE_RUNNER_BACKEND
|
144
146
|
|
145
147
|
|
@@ -147,7 +149,7 @@ def get_deepep_mode() -> DeepEPMode:
|
|
147
149
|
global DEEPEP_MODE
|
148
150
|
if DEEPEP_MODE is None:
|
149
151
|
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
|
150
|
-
DEEPEP_MODE = DeepEPMode
|
152
|
+
DEEPEP_MODE = DeepEPMode.AUTO
|
151
153
|
return DEEPEP_MODE
|
152
154
|
|
153
155
|
|
@@ -162,7 +164,6 @@ def get_deepep_config() -> str:
|
|
162
164
|
def is_tbo_enabled() -> bool:
|
163
165
|
global IS_TBO_ENABLED
|
164
166
|
if IS_TBO_ENABLED is None:
|
165
|
-
logger.warning("IS_TBO_ENABLED is not initialized, using False")
|
166
167
|
IS_TBO_ENABLED = False
|
167
168
|
return IS_TBO_ENABLED
|
168
169
|
|
@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
|
|
34
34
|
|
35
35
|
if TYPE_CHECKING:
|
36
36
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
37
|
-
from sglang.srt.layers.moe.
|
37
|
+
from sglang.srt.layers.moe.token_dispatcher import (
|
38
|
+
StandardDispatchOutput,
|
39
|
+
CombineInput,
|
40
|
+
)
|
38
41
|
|
39
42
|
from sglang.srt.utils import is_cuda, is_hip
|
40
43
|
|
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|
736
739
|
)
|
737
740
|
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
|
738
741
|
|
742
|
+
def create_moe_runner(
|
743
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
744
|
+
):
|
745
|
+
self.moe_runner_config = moe_runner_config
|
746
|
+
|
739
747
|
def apply(
|
740
748
|
self,
|
741
749
|
layer: torch.nn.Module,
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
750
|
+
dispatch_output: StandardDispatchOutput,
|
751
|
+
) -> CombineInput:
|
752
|
+
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
753
|
+
|
746
754
|
assert (
|
747
|
-
moe_runner_config.activation == "silu"
|
755
|
+
self.moe_runner_config.activation == "silu"
|
748
756
|
), "Only SiLU activation is supported."
|
749
757
|
|
750
758
|
# The input must currently be float16
|
759
|
+
x = dispatch_output.hidden_states
|
760
|
+
topk_output = dispatch_output.topk_output
|
761
|
+
|
751
762
|
orig_dtype = x.dtype
|
752
763
|
x = x.half()
|
753
764
|
|
754
765
|
topk_weights, topk_ids, router_logits = topk_output
|
755
766
|
|
756
|
-
|
767
|
+
output = fused_marlin_moe(
|
757
768
|
x,
|
758
769
|
layer.w13_qweight,
|
759
770
|
layer.w2_qweight,
|
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|
768
779
|
w2_zeros=layer.w2_qzeros,
|
769
780
|
num_bits=self.quant_config.weight_bits,
|
770
781
|
).to(orig_dtype)
|
782
|
+
return StandardCombineInput(hidden_states=output)
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
|
4
4
|
import inspect
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
+
from dataclasses import dataclass
|
6
7
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
7
8
|
|
8
9
|
import torch
|
@@ -10,7 +11,7 @@ from torch import nn
|
|
10
11
|
|
11
12
|
if TYPE_CHECKING:
|
12
13
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
13
|
-
from sglang.srt.layers.moe.
|
14
|
+
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
|
14
15
|
|
15
16
|
|
16
17
|
class QuantizeMethodBase(ABC):
|
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|
89
90
|
layer: torch.nn.Module,
|
90
91
|
num_experts: int,
|
91
92
|
hidden_size: int,
|
92
|
-
|
93
|
+
intermediate_size_per_partition: int,
|
93
94
|
params_dtype: torch.dtype,
|
94
95
|
**extra_weight_attrs,
|
95
96
|
):
|
96
97
|
raise NotImplementedError
|
97
98
|
|
99
|
+
@abstractmethod
|
100
|
+
def create_moe_runner(
|
101
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
102
|
+
):
|
103
|
+
raise NotImplementedError
|
104
|
+
|
98
105
|
@abstractmethod
|
99
106
|
def apply(
|
100
107
|
self,
|
101
108
|
layer: torch.nn.Module,
|
102
|
-
|
103
|
-
|
104
|
-
moe_runner_config: MoeRunnerConfig,
|
105
|
-
) -> torch.Tensor:
|
109
|
+
dispatch_output: DispatchOutput,
|
110
|
+
) -> CombineInput:
|
106
111
|
raise NotImplementedError
|
107
112
|
|
108
113
|
|
@@ -9,6 +9,8 @@ import torch
|
|
9
9
|
from torch.nn import Module
|
10
10
|
|
11
11
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
12
|
+
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
13
|
+
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
12
14
|
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
13
15
|
from sglang.srt.layers.quantization.base_config import (
|
14
16
|
FusedMoEMethodBase,
|
@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
|
|
22
24
|
from sglang.srt.utils import set_weight_attrs
|
23
25
|
|
24
26
|
if TYPE_CHECKING:
|
25
|
-
from sglang.srt.layers.moe.
|
26
|
-
|
27
|
+
from sglang.srt.layers.moe.token_dispatcher import (
|
28
|
+
CombineInput,
|
29
|
+
StandardDispatchOutput,
|
30
|
+
)
|
27
31
|
|
28
32
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
29
33
|
|
@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
257
261
|
layer: Module,
|
258
262
|
num_experts: int,
|
259
263
|
hidden_size: int,
|
260
|
-
|
264
|
+
intermediate_size_per_partition: int,
|
261
265
|
params_dtype: torch.dtype,
|
262
266
|
**extra_weight_attrs,
|
263
267
|
):
|
@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
273
277
|
)
|
274
278
|
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
275
279
|
# Required by column parallel or enabling merged weights
|
276
|
-
if
|
280
|
+
if intermediate_size_per_partition % block_n != 0:
|
277
281
|
raise ValueError(
|
278
282
|
f"The output_size of gate's and up's weight = "
|
279
|
-
f"{
|
283
|
+
f"{intermediate_size_per_partition} is not divisible by "
|
280
284
|
f"weight quantization block_n = {block_n}."
|
281
285
|
)
|
282
286
|
if tp_size > 1:
|
283
287
|
# Required by row parallel
|
284
|
-
if
|
288
|
+
if intermediate_size_per_partition % block_k != 0:
|
285
289
|
raise ValueError(
|
286
290
|
f"The input_size of down's weight = "
|
287
|
-
f"{
|
291
|
+
f"{intermediate_size_per_partition} is not divisible by "
|
288
292
|
f"weight quantization block_k = {block_k}."
|
289
293
|
)
|
290
294
|
|
291
295
|
# WEIGHTS
|
292
296
|
w13_weight = torch.nn.Parameter(
|
293
297
|
torch.empty(
|
294
|
-
num_experts,
|
298
|
+
num_experts,
|
299
|
+
2 * intermediate_size_per_partition,
|
300
|
+
hidden_size,
|
301
|
+
dtype=params_dtype,
|
295
302
|
),
|
296
303
|
requires_grad=False,
|
297
304
|
)
|
@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
300
307
|
|
301
308
|
w2_weight = torch.nn.Parameter(
|
302
309
|
torch.empty(
|
303
|
-
num_experts,
|
310
|
+
num_experts,
|
311
|
+
hidden_size,
|
312
|
+
intermediate_size_per_partition,
|
313
|
+
dtype=params_dtype,
|
304
314
|
),
|
305
315
|
requires_grad=False,
|
306
316
|
)
|
@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
311
321
|
w13_weight_scale = torch.nn.Parameter(
|
312
322
|
torch.ones(
|
313
323
|
num_experts,
|
314
|
-
2 * ((
|
324
|
+
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
|
315
325
|
(hidden_size + block_k - 1) // block_k,
|
316
326
|
dtype=torch.float32,
|
317
327
|
),
|
@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
321
331
|
torch.ones(
|
322
332
|
num_experts,
|
323
333
|
(hidden_size + block_n - 1) // block_n,
|
324
|
-
(
|
334
|
+
(intermediate_size_per_partition + block_k - 1) // block_k,
|
325
335
|
dtype=torch.float32,
|
326
336
|
),
|
327
337
|
requires_grad=False,
|
@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
344
354
|
# Block quant doesn't need to process weights after loading
|
345
355
|
return
|
346
356
|
|
357
|
+
def create_moe_runner(
|
358
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
359
|
+
):
|
360
|
+
self.moe_runner_config = moe_runner_config
|
361
|
+
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
362
|
+
|
347
363
|
def apply(
|
348
364
|
self,
|
349
365
|
layer: torch.nn.Module,
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
# Expert fusion with INT8 quantization
|
357
|
-
return fused_experts(
|
358
|
-
x,
|
359
|
-
layer.w13_weight,
|
360
|
-
layer.w2_weight,
|
361
|
-
topk_output=topk_output,
|
362
|
-
moe_runner_config=moe_runner_config,
|
366
|
+
dispatch_output: StandardDispatchOutput,
|
367
|
+
) -> CombineInput:
|
368
|
+
|
369
|
+
quant_info = TritonMoeQuantInfo(
|
370
|
+
w13_weight=layer.w13_weight,
|
371
|
+
w2_weight=layer.w2_weight,
|
363
372
|
use_int8_w8a8=True,
|
364
|
-
|
365
|
-
w2_scale=
|
366
|
-
|
373
|
+
w13_scale=layer.w13_weight_scale_inv,
|
374
|
+
w2_scale=layer.w2_weight_scale_inv,
|
375
|
+
a13_scale=layer.w13_input_scale,
|
367
376
|
a2_scale=layer.w2_input_scale,
|
368
377
|
block_shape=self.quant_config.weight_block_size,
|
369
378
|
)
|
379
|
+
|
380
|
+
return self.runner.run(dispatch_output, quant_info)
|