sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -40,9 +40,10 @@ import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
@@ -240,6 +241,9 @@ class ForwardBatch:
     prefix_chunk_num_tokens: Optional[List[int]] = None
     # KV Indices for each chunk
     prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+    # For MLA chunked prefix cache used in chunked prefill
+    # Tell attention backend whether lse needs to be returned
+    mha_return_lse: Optional[bool] = None
 
     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -274,13 +278,13 @@ class ForwardBatch:
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[
+    dp_padding_mode: Optional[DpPaddingMode] = None
    # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +632,7 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size
 
-        dp_padding_mode =
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
@@ -642,17 +646,14 @@ class ForwardBatch:
         else:
             buffer_len = sum(global_num_tokens)
 
-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
-
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
+
         bs = self.batch_size
 
         if self.forward_mode.is_decode():
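The change above removes the per-batch `gathered_buffer` allocation from `ForwardBatch`: the batch now only records `global_dp_buffer_len` and publishes it to the DP-attention layer via `set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)`. A rough standalone sketch of the buffer-sizing logic is below; the sum branch mirrors the `buffer_len = sum(global_num_tokens)` line in the diff, while the max-length branch and the `PaddingMode` class are simplified stand-ins for sglang's `DpPaddingMode`, not the real API.

```python
from enum import Enum
from typing import List


class PaddingMode(Enum):
    # Pad every DP rank to the same length (simplified stand-in for DpPaddingMode).
    MAX_LEN = "max_len"
    # Lay ranks out back-to-back with their true token counts.
    SUM_LEN = "sum_len"

    def is_max_len(self) -> bool:
        return self is PaddingMode.MAX_LEN


def compute_dp_buffer_len(global_num_tokens: List[int], mode: PaddingMode) -> int:
    """Number of token slots the shared DP gather buffer must hold for one step."""
    if mode.is_max_len():
        # Assumed sizing for the padded case: every rank contributes max tokens.
        return max(global_num_tokens) * len(global_num_tokens)
    return sum(global_num_tokens)


if __name__ == "__main__":
    tokens_per_rank = [7, 3, 5, 5]
    print(compute_dp_buffer_len(tokens_per_rank, PaddingMode.MAX_LEN))  # 28
    print(compute_dp_buffer_len(tokens_per_rank, PaddingMode.SUM_LEN))  # 20
```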
sglang/srt/model_executor/model_runner.py CHANGED
@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -75,12 +74,12 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
-    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
@@ -92,10 +91,16 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
     is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -168,6 +172,7 @@ class ModelRunner:
         pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
@@ -176,10 +181,6 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
-
-        # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.moe_ep_rank = moe_ep_rank
@@ -205,15 +206,17 @@ class ModelRunner:
         self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
-
         self.forward_pass_id = 0
 
-        #
-
-
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         if server_args.show_time_cost:
             enable_show_time_cost()
 
+        # Model-specific adjustment
+        self.model_specific_adjustment()
+
         # Global vars
         global_server_args_dict.update(
             {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@@ -222,15 +225,8 @@ class ModelRunner:
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
             }
-            | {
-                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
-                "deepep_mode": DeepEPMode(server_args.deepep_mode),
-            }
         )
 
-        # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
             self.init_threads_binding()
@@ -238,17 +234,22 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # CPU offload
+        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
-        #
+        # Initialize the model runner
         self.initialize(min_per_gpu_memory)
 
-        #
+        # Temporary cached values
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+
+        # For weight updates
         self._model_update_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
@@ -277,6 +278,7 @@ class ModelRunner:
             )
         )
 
+        # Expert parallelism
        self.eplb_manager = (
             EPLBManager(self)
             if self.server_args.enable_eplb and (not self.is_draft_worker)
@@ -310,8 +312,13 @@ class ModelRunner:
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
-        assert (
-
+        assert (
+            (not model_has_mtp_layers)
+            or (self.spec_algorithm.is_none())
+            or (
+                (not self.spec_algorithm.is_none())
+                and (self.num_effective_layers == model_num_layers)
+            )
         ), "PP is not compatible with MTP models."
 
         # Apply torchao quantization
@@ -340,9 +347,12 @@ class ModelRunner:
         if self.device == "cuda":
             self.init_cublas()
             self.init_attention_backend()
-            self.
+            self.init_device_graphs()
+        elif self.device == "npu":
+            self.init_attention_backend()
+            self.init_device_graphs()
         else:
-            self.
+            self.graph_runner = None
             self.cuda_graph_mem_usage = 0
             self.init_attention_backend()
 
@@ -509,9 +519,6 @@ class ModelRunner:
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
-        elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
-            server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
@@ -604,12 +611,8 @@ class ModelRunner:
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
-
-
-            tp_size=self.tp_size,
-            dp_size=self.server_args.dp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-            pp_size=self.server_args.pp_size,
+            server_args=self.server_args,
+            model_config=self.model_config,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
@@ -689,6 +692,8 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state(reverse=True)
         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
+        get_offloader().post_init()
+
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
@@ -920,7 +925,8 @@ class ModelRunner:
         )
 
         # We need to get device after patch otherwise the device would be wrong
-
+        self.device_module = torch.get_device_module(self.device)
+        infered_device = self.device_module.current_device()
 
         named_tensors = [
             (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
@@ -1051,8 +1057,6 @@ class ModelRunner:
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
-            # FIXME: pipeline parallelism is not compatible with mla backend
-            assert self.pp_size == 1
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * num_layers
@@ -1160,6 +1164,7 @@ class ModelRunner:
         max_num_reqs: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
+        # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
@@ -1178,6 +1183,8 @@ class ModelRunner:
         )
 
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
 
         if max_num_reqs is None:
             max_num_reqs = min(
@@ -1190,9 +1197,6 @@ class ModelRunner:
                 4096,
             )
 
-        if SGLANG_CI_SMALL_KV_SIZE:
-            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
-
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -1239,7 +1243,13 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
+        # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
+            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
+            extra_max_context_len = 4
+            if self.server_args.speculative_num_draft_tokens is not None:
+                extra_max_context_len += self.server_args.speculative_num_draft_tokens
+
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
 
@@ -1248,7 +1258,8 @@ class ModelRunner:
                 pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                 self.req_to_token_pool = DecodeReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                     pre_alloc_size=pre_alloc_size,
@@ -1256,7 +1267,8 @@ class ModelRunner:
             else:
                 self.req_to_token_pool = ReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
@@ -1264,6 +1276,7 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
+        # Initialize token_to_kv_pool
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1349,38 +1362,40 @@ class ModelRunner:
                 end_layer=self.end_layer,
             )
 
+        # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if self.
-
-                self.
-
-
-
-
-
-
-                )
-            else:
-                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                    self.max_total_num_tokens,
-                    dtype=self.kv_cache_dtype,
-                    device=self.device,
-                    kvcache=self.token_to_kv_pool,
-                    need_sort=need_sort,
-                )
+            if self.server_args.attention_backend == "ascend":
+                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                    need_sort=need_sort,
+                )
             else:
-                if
-                    self.
-                    self.
-
-
-
-
-
-
+                if self.page_size == 1:
+                    if self.is_hybrid:
+                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                            self.full_max_total_num_tokens,
+                            self.swa_max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
+                    else:
+                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                            self.max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
                 else:
-                    self.
+                    assert not self.is_hybrid
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
@@ -1554,15 +1569,13 @@ class ModelRunner:
             )
 
             return TRTLLMHAAttnBackend(self)
-
         elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
 
-            logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
-        elif
+        elif backend_str == "dual_chunk_flash_attn":
             from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                 DualChunkFlashAttentionBackend,
             )
@@ -1588,9 +1601,9 @@ class ModelRunner:
             .cuda()
         )
 
-    def
+    def init_device_graphs(self):
         """Capture cuda graphs."""
-        self.
+        self.graph_runner = None
         self.cuda_graph_mem_usage = 0
 
         if not self.is_generation:
@@ -1605,7 +1618,9 @@ class ModelRunner:
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.
+        self.graph_runner = (
+            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
@@ -1757,11 +1772,11 @@ class ModelRunner:
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
-            and self.
-            and self.
+            and self.graph_runner
+            and self.graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            ret = self.
+            ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
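With this release ModelRunner no longer sets a raw CPU-offload byte budget (`set_cpu_offload_max_bytes`); it installs a process-global offloader built from the server args and later calls `get_offloader().post_init()` once weights are loaded. The sketch below shows only that set/get singleton pattern; the class bodies, the `cpu_offload_gb` handling, and the `BaseOffloader`/`CpuOffloader` names are hypothetical stand-ins, not sglang's `sglang/srt/offloader.py` implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ServerArgsStub:
    # Hypothetical stand-in for sglang's ServerArgs; only the field used here.
    cpu_offload_gb: float = 0.0


class BaseOffloader:
    """No-op offloader used when CPU offload is disabled."""

    def post_init(self) -> None:
        # Called once after model weights are loaded.
        pass


class CpuOffloader(BaseOffloader):
    def __init__(self, max_bytes: int, dp_rank: Optional[int] = None):
        self.max_bytes = max_bytes
        self.dp_rank = dp_rank

    def post_init(self) -> None:
        print(f"offloader ready: budget={self.max_bytes} bytes, dp_rank={self.dp_rank}")


_OFFLOADER: BaseOffloader = BaseOffloader()


def set_offloader(offloader: BaseOffloader) -> None:
    global _OFFLOADER
    _OFFLOADER = offloader


def get_offloader() -> BaseOffloader:
    return _OFFLOADER


def create_offloader_from_server_args(server_args: ServerArgsStub, dp_rank: Optional[int] = None):
    if server_args.cpu_offload_gb > 0:
        return CpuOffloader(int(server_args.cpu_offload_gb * 1024**3), dp_rank)
    return BaseOffloader()


# Mirrors the call order in ModelRunner.__init__ / model loading above.
set_offloader(create_offloader_from_server_args(ServerArgsStub(cpu_offload_gb=4), dp_rank=0))
get_offloader().post_init()
```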
sglang/srt/model_executor/npu_graph_runner.py ADDED
@@ -0,0 +1,94 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run the model with npu graph and torch.compile."""
+
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+
+class NPUGraphRunner(CudaGraphRunner):
+    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__(model_runner)
+
+    def _create_device_graph(self):
+        return torch.npu.NPUGraph()
+
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with torch.npu.graph(
+            graph,
+            pool=pool,
+            stream=stream,
+            auto_dispatch_capture=True,
+        ):
+            out = run_once_fn()
+        return out
+
+    def _update_inputs(self, seq_lens):
+        self.graphs[self.bs].update(
+            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
+        )
+
+    def _cache_loc_dtype(self):
+        return torch.int32
+
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
+        # Replay
+        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
+        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+        thread.start()
+        self.graphs[self.bs].replay()
+        thread.join()
+
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
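One detail worth noting in `NPUGraphRunner.replay` above: the host-side `graph.update(...)` call (which refreshes `actual_seq_lengths_kv`) runs on a helper thread while the captured graph is replayed, and is joined before the outputs are read. A toy illustration of that overlap pattern follows; `FakeGraph` and its timings are invented for the example and are not the torch.npu API.

```python
import threading
import time


class FakeGraph:
    """Stand-in for a captured device graph (torch.npu.NPUGraph in the diff above)."""

    def update(self, cpu_update_input):
        # Host-side metadata update that the captured kernels will consume.
        time.sleep(0.01)
        self.meta = cpu_update_input

    def replay(self):
        # Launches the captured kernels; on a real device this is asynchronous.
        time.sleep(0.02)


def replay_with_async_update(graph: FakeGraph, seq_lens):
    # Overlap the host-side update with graph replay, as NPUGraphRunner.replay does.
    thread = threading.Thread(
        target=graph.update, args=([{"actual_seq_lengths_kv": seq_lens}],)
    )
    thread.start()
    graph.replay()
    thread.join()


replay_with_async_update(FakeGraph(), [5, 7, 0, 0])
```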
sglang/srt/model_loader/loader.py CHANGED
@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         yield module
         return
 
-
+    original_infos: Dict[str, Dict] = {}
 
     # Store original device states and move parameters to GPU if they're on CPU
     for name, p in module.named_parameters():
         if p.device.type == "cpu":
-
-
+            original_data = p.data
+            device_data = p.data.to(target_device)
+            original_infos[name] = dict(
+                device=p.device,
+                original_data=original_data,
+                device_data=device_data,
+            )
+            p.data = device_data
         # Parameters already on target device are not touched
 
     try:
@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         # Restore parameters to their original devices, ignoring new parameters
         pin_memory = is_pin_memory_available()
         for name, p in module.named_parameters():
-            if name in
-
-
+            if name in original_infos:
+                original_info = original_infos[name]
+                device_data = original_info["device_data"]
+                original_data = original_info["original_data"]
+                original_device: torch.device = original_info["device"]
+
+                if (
+                    (device_data.device == p.data.device)
+                    and (device_data.data_ptr() == p.data.data_ptr())
+                    and (device_data.shape == p.data.shape)
+                    and (device_data.dtype == p.data.dtype)
+                ):
+                    original_data.copy_(p.data.to(original_data.device))
+                    p.data = original_data
+                elif original_device.type == "cpu":
                     # `torch.empty_like` does not support `pin_memory` argument
                     cpu_data = torch.empty_strided(
                         size=p.data.size(),
sglang/srt/models/dbrx.py CHANGED
@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
         self.params_dtype = params_dtype
 
         self.router = DbrxRouter(config, self.params_dtype)
+        self.topk = TopK(
+            self.top_k,
+            renormalize=True,
+        )
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)
         self.ws = nn.Parameter(
             torch.empty(
                 self.num_total_experts,
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.ws,
             self.w2s,
-
-            self.
-            renormalize=True,
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )
 
         if self.tp_size > 1:
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
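Both dbrx.py here and deepseek.py below move to the newer MoE calling convention: router logits are first turned into a `topk_output` by a `TopK` module, and per-call flags such as `inplace` travel inside a `MoeRunnerConfig` object instead of loose keyword arguments. The snippet below only sketches that call shape with placeholder objects; the stub `TopK`, `MoeRunnerConfig`, and `fused_moe` here are not the real sglang implementations, which operate on GPU tensors via Triton kernels.

```python
from dataclasses import dataclass


@dataclass
class MoeRunnerConfig:
    # Per-call execution flags that used to be passed as keyword arguments.
    inplace: bool = False


class TopK:
    """Placeholder: turns router logits into a routing decision (weights, expert ids)."""

    def __init__(self, top_k: int, renormalize: bool = True):
        self.top_k = top_k
        self.renormalize = renormalize

    def __call__(self, hidden_states, router_logits):
        return ("topk_weights", "topk_ids")


def fused_moe(hidden_states, w1, w2, topk_output, moe_runner_config):
    # Placeholder for the Triton kernel entry point; only the call shape matters here.
    return hidden_states


# New call convention used by DbrxExperts.forward / DeepseekMoE.forward:
topk = TopK(top_k=4, renormalize=True)
moe_runner_config = MoeRunnerConfig(inplace=True)

hidden_states, router_logits, ws, w2s = "h", "logits", "w1", "w2"
topk_output = topk(hidden_states, router_logits)
out = fused_moe(hidden_states, ws, w2s, topk_output, moe_runner_config)
```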
sglang/srt/models/deepseek.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
             w1=self.w1,
             w2=self.w2,
             topk_output=topk_output,
-            inplace=True,
+            moe_runner_config=MoeRunnerConfig(inplace=True),
         )
 
         if self.config.n_shared_experts is not None: