sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4v_moe.py
ADDED
@@ -0,0 +1,400 @@
+import logging
+from functools import lru_cache
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    parallel_state,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.glm4_moe import Glm4MoeModel
+from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+
+_is_cuda = is_cuda()
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+    def __init__(
+        self,
+        config: Glm4vMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+
+        config.moe_layer_freq = 1
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.dp_size = get_local_attention_dp_size()
+        self.quant_config = quant_config
+        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+
+        self.model = Glm4MoeModel(
+            config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.visual = Glm4vVisionModel(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+    def determine_num_fused_shared_experts(
+        self, architecture: str = "Glm4MoeForCausalLM"
+    ):
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is not None:
+                if self.quant_config.get_name() == "w8a8_int8":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif (
+                    self.quant_config.get_name() == "fp8"
+                    or self.quant_config.get_name() == "blockwise_int8"
+                    or self.quant_config.get_name() == "compressed_tensors"
+                ):
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif self.quant_config.get_name() == "awq":
+                    suffix_list = [
+                        "down_proj.qweight",
+                        "down_proj.qzeros",
+                        "down_proj.scales",
+                        "gate_proj.qweight",
+                        "gate_proj.qzeros",
+                        "gate_proj.scales",
+                        "up_proj.qweight",
+                        "up_proj.qzeros",
+                        "up_proj.scales",
+                    ]
+                elif self.quant_config.get_name() == "modelopt_fp4":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "down_proj.weight_scale_2",
+                        "down_proj.input_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "gate_proj.weight_scale_2",
+                        "gate_proj.input_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                        "up_proj.weight_scale_2",
+                        "up_proj.input_scale",
+                    ]
+                else:
+                    raise ValueError(
+                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                    )
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "gate_proj.weight",
+                    "up_proj.weight",
+                ]
+            names_to_remove = []
+
+            moe_layers = (
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in moe_layers:
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    # online fp8 quantization does not load weight_scale
+                    if shared_expert_weight_name not in weights_dict:
+                        continue
+                    weights_list.append(
+                        (
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + 0}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
+        params_dict = dict(self.named_parameters())
+        weight_names = []
+        for name, loaded_weight in weights:
+            weight_names.append(name)
+
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            if "language_model." in name:
+                name = name.replace("language_model.", "")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                        else:
+                            if (
+                                "k_scale" in name or "v_scale" in name
+                            ) and name not in params_dict:
+                                # modelopt attn kv scale is named differently
+                                if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                    name = name.replace("_proj", "attn_mqa")
+                                else:
+                                    logger.warning(
+                                        f"Unknown scale found in checkpoint: {name}"
+                                    )
+                            param = params_dict[name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4vMoeForConditionalGeneration]
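Note on the shared-experts fusion path in the new load_weights above: when fusion is enabled, each mlp.shared_experts.* tensor in the checkpoint is re-registered under the extra routed-expert slot (index n_routed_experts) and the original entry is dropped, so the fused expert is then picked up by the normal expert_params_mapping pass. The standalone sketch below mirrors that renaming step; the helper name and the toy tensors are illustrative only and are not part of sglang.

from typing import List, Tuple

import torch


def fuse_shared_expert_weights(
    weights: List[Tuple[str, torch.Tensor]],
    moe_layers: List[int],
    n_routed_experts: int,
    suffixes: List[str],
) -> List[Tuple[str, torch.Tensor]]:
    # Mirror of the renaming loop in Glm4vMoeForConditionalGeneration.load_weights:
    # each shared-expert tensor is moved to the extra expert slot n_routed_experts.
    weights_dict = dict(weights)
    out = list(weights)
    names_to_remove = []
    for layer in moe_layers:
        for suffix in suffixes:
            src = f"model.layers.{layer}.mlp.shared_experts.{suffix}"
            if src not in weights_dict:  # e.g. online fp8 ships no weight_scale
                continue
            dst = f"model.layers.{layer}.mlp.experts.{n_routed_experts}.{suffix}"
            out.append((dst, weights_dict[src]))
            names_to_remove.append(src)
    return [w for w in out if w[0] not in names_to_remove]


# Toy usage; the layer index, expert count, and shapes are made up.
weights = [
    ("model.layers.3.mlp.shared_experts.gate_proj.weight", torch.zeros(4, 4)),
    ("model.layers.3.mlp.experts.0.gate_proj.weight", torch.zeros(4, 4)),
]
print([n for n, _ in fuse_shared_expert_weights(
    weights, moe_layers=[3], n_routed_experts=128, suffixes=["gate_proj.weight"]
)])
# ['model.layers.3.mlp.experts.0.gate_proj.weight',
#  'model.layers.3.mlp.experts.128.gate_proj.weight']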
sglang/srt/models/gpt_oss.py
CHANGED
@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -56,7 +57,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -64,7 +65,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    make_layers,
+)
+
+_is_cuda = is_cuda()
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
 
 
 class GptOssConfig(PretrainedConfig):
@@ -151,10 +166,13 @@ class GptOssSparseMoeBlock(nn.Module):
         )
 
     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")
 
@@ -165,7 +183,11 @@ class GptOssSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -179,13 +201,39 @@ class GptOssSparseMoeBlock(nn.Module):
         kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         ans = final_hidden_states.view(num_tokens, hidden_dim)
         return ans
 
 
+def _enable_fused_set_kv_buffer():
+    return _is_cuda
+
+
+# TODO maybe move to a model-common utils
+def _create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
+
+
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -246,8 +294,12 @@ class GptOssAttention(nn.Module):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
+        # others can use bfloat16
+        attn_backend = global_server_args_dict.get("attention_backend")
+        sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=
+            torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
@@ -293,7 +345,21 @@ class GptOssAttention(nn.Module):
             return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-
+
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                _create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if _enable_fused_set_kv_buffer()
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -301,7 +367,11 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(
+        attn_output = self.attn(
+            *inner_state,
+            sinks=self.sinks,
+            save_kv_cache=not _enable_fused_set_kv_buffer(),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -370,6 +440,7 @@ class GptOssDecoderLayer(nn.Module):
 
         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
+        self.is_nextn = False
         is_previous_layer_sparse = True
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
@@ -402,6 +473,42 @@ class GptOssDecoderLayer(nn.Module):
             post_attention_layernorm=self.post_attention_layernorm,
         )
 
+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+
+        if batch_size > 128:
+            return False
+
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
+
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -424,12 +531,21 @@ class GptOssDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-
-
-
-            hidden_states, residual, forward_batch
+        should_allreduce_fusion = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not self.is_nextn
         )
 
+        hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not should_allreduce_fusion:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
 
@@ -450,7 +566,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
@@ -550,6 +666,18 @@ class GptOssForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
 
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,
@@ -1033,12 +1161,6 @@ class GptOssForCausalLM(nn.Module):
         else:
             logging.info("All parameters loaded successfully.")
 
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
-        }
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 
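Note on the allreduce-fusion hooks added to GptOssDecoderLayer above: _build_fuse_allreduce_lookup_table precomputes, per layer, whether the MLP all-reduce may be folded into the next layer's residual RMSNorm, and _should_fuse_mlp_allreduce_with_next_layer consults it at runtime. The net effect is a simple predicate: fuse only for batch sizes 1 to 128 on a non-final layer, with tensor parallelism active, the FlashInfer allreduce-fusion flag set, FlashInfer installed, and an SM100-class GPU. A minimal sketch of that predicate, with plain arguments standing in for the server args and config lookups, follows.

def should_fuse_mlp_allreduce(
    batch_size: int,
    layer_id: int,
    num_hidden_layers: int,
    tp_size: int,
    enable_flashinfer_allreduce_fusion: bool,
    is_sm100: bool,
    flashinfer_available: bool,
) -> bool:
    # Same decision the lookup table encodes, written as a single predicate.
    # Plain arguments stand in for global_server_args_dict / config lookups.
    static_ok = (
        layer_id != num_hidden_layers - 1  # never fuse on the last layer
        and tp_size > 1                    # only matters when TP does an all-reduce
        and enable_flashinfer_allreduce_fusion
        and is_sm100
        and flashinfer_available
    )
    return static_ok and 0 < batch_size <= 128  # small batches only


# Example: 64 tokens, layer 10 of 36, TP=4, all features available -> fuse.
print(should_fuse_mlp_allreduce(64, 10, 36, 4, True, True, True))   # True
print(should_fuse_mlp_allreduce(256, 10, 36, 4, True, True, True))  # False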