sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,400 @@
+import logging
+from functools import lru_cache
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    parallel_state,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.glm4_moe import Glm4MoeModel
+from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+
+_is_cuda = is_cuda()
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+    def __init__(
+        self,
+        config: Glm4vMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+
+        config.moe_layer_freq = 1
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.dp_size = get_local_attention_dp_size()
+        self.quant_config = quant_config
+        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+
+        self.model = Glm4MoeModel(
+            config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.visual = Glm4vVisionModel(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+    def determine_num_fused_shared_experts(
+        self, architecture: str = "Glm4MoeForCausalLM"
+    ):
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is not None:
+                if self.quant_config.get_name() == "w8a8_int8":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif (
+                    self.quant_config.get_name() == "fp8"
+                    or self.quant_config.get_name() == "blockwise_int8"
+                    or self.quant_config.get_name() == "compressed_tensors"
+                ):
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif self.quant_config.get_name() == "awq":
+                    suffix_list = [
+                        "down_proj.qweight",
+                        "down_proj.qzeros",
+                        "down_proj.scales",
+                        "gate_proj.qweight",
+                        "gate_proj.qzeros",
+                        "gate_proj.scales",
+                        "up_proj.qweight",
+                        "up_proj.qzeros",
+                        "up_proj.scales",
+                    ]
+                elif self.quant_config.get_name() == "modelopt_fp4":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "down_proj.weight_scale_2",
+                        "down_proj.input_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "gate_proj.weight_scale_2",
+                        "gate_proj.input_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                        "up_proj.weight_scale_2",
+                        "up_proj.input_scale",
+                    ]
+                else:
+                    raise ValueError(
+                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                    )
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "gate_proj.weight",
+                    "up_proj.weight",
+                ]
+            names_to_remove = []
+
+            moe_layers = (
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in moe_layers:
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    # online fp8 quantization does not load weight_scale
+                    if shared_expert_weight_name not in weights_dict:
+                        continue
+                    weights_list.append(
+                        (
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + 0}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
+        params_dict = dict(self.named_parameters())
+        weight_names = []
+        for name, loaded_weight in weights:
+            weight_names.append(name)
+
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            if "language_model." in name:
+                name = name.replace("language_model.", "")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                        else:
+                            if (
+                                "k_scale" in name or "v_scale" in name
+                            ) and name not in params_dict:
+                                # modelopt attn kv scale is named differently
+                                if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                    name = name.replace("_proj", "attn_mqa")
+                                else:
+                                    logger.warning(
+                                        f"Unknown scale found in checkpoint: {name}"
+                                    )
+                            param = params_dict[name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4vMoeForConditionalGeneration]
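The new Glm4vMoeForConditionalGeneration.load_weights folds the single shared expert into the routed-expert table by renaming its checkpoint entries to the extra expert index n_routed_experts before the normal expert mapping runs. A minimal sketch of that renaming step, with placeholder values (the real suffix list and layer range depend on the quantization config and model config):

```python
import torch

# Sketch of the shared-expert renaming done in load_weights above.
# N_ROUTED_EXPERTS=128, layers 1..3, and the suffix list are illustrative
# placeholders, not values read from a real GLM-4.5V checkpoint.
N_ROUTED_EXPERTS = 128
SUFFIX_LIST = ["down_proj.weight", "gate_proj.weight", "up_proj.weight"]


def fuse_shared_expert_names(weights_dict, moe_layers):
    """Return (new_entries, names_to_remove), mirroring the renaming in the diff."""
    new_entries, names_to_remove = [], []
    for layer in moe_layers:
        for suffix in SUFFIX_LIST:
            shared_name = f"model.layers.{layer}.mlp.shared_experts.{suffix}"
            if shared_name not in weights_dict:
                continue  # e.g. online fp8 quantization ships no weight_scale entries
            fused_name = f"model.layers.{layer}.mlp.experts.{N_ROUTED_EXPERTS}.{suffix}"
            new_entries.append((fused_name, weights_dict[shared_name]))
            names_to_remove.append(shared_name)
    return new_entries, names_to_remove


# Fake checkpoint with one shared-expert tensor per MoE layer.
fake_ckpt = {
    f"model.layers.{i}.mlp.shared_experts.down_proj.weight": torch.zeros(1)
    for i in range(1, 4)
}
added, removed = fuse_shared_expert_names(fake_ckpt, moe_layers=range(1, 4))
print(added[0][0])  # model.layers.1.mlp.experts.128.down_proj.weight
```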
sglang/srt/models/gpt_oss.py  CHANGED
@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -64,7 +64,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    make_layers,
+)
+
+_is_cuda = is_cuda()
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
 
 
 class GptOssConfig(PretrainedConfig):
@@ -151,10 +165,13 @@ class GptOssSparseMoeBlock(nn.Module):
         )
 
     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")
 
@@ -165,7 +182,11 @@ class GptOssSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -179,13 +200,39 @@ class GptOssSparseMoeBlock(nn.Module):
             kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         ans = final_hidden_states.view(num_tokens, hidden_dim)
         return ans
 
 
+def _enable_fused_set_kv_buffer():
+    return _is_cuda
+
+
+# TODO maybe move to a model-common utils
+def _create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
+
+
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -247,7 +294,7 @@ class GptOssAttention(nn.Module):
         )
 
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=
+            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
@@ -293,7 +340,21 @@ class GptOssAttention(nn.Module):
             return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-
+
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                _create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if _enable_fused_set_kv_buffer()
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -301,7 +362,11 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(
+        attn_output = self.attn(
+            *inner_state,
+            sinks=self.sinks,
+            save_kv_cache=not _enable_fused_set_kv_buffer(),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -370,6 +435,7 @@ class GptOssDecoderLayer(nn.Module):
 
         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
+        self.is_nextn = False
         is_previous_layer_sparse = True
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
@@ -402,6 +468,42 @@ class GptOssDecoderLayer(nn.Module):
             post_attention_layernorm=self.post_attention_layernorm,
         )
 
+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+
+        if batch_size > 128:
+            return False
+
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
+
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -424,12 +526,21 @@ class GptOssDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-
-
-
-            hidden_states, residual, forward_batch
+        should_allreduce_fusion = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not self.is_nextn
         )
 
+        hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not should_allreduce_fusion:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
 
@@ -550,6 +661,18 @@ class GptOssForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
 
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,
@@ -1033,12 +1156,6 @@ class GptOssForCausalLM(nn.Module):
         else:
             logging.info("All parameters loaded successfully.")
 
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
-        }
-
    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

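The GptOssDecoderLayer change gates the FlashInfer allreduce+residual-RMSNorm fusion on a lookup table precomputed per batch size: fusion is only considered for batches of 1 to 128 tokens, never on the last layer, and only when the static checks (TP > 1, SM100, FlashInfer available, server flag) pass. A standalone sketch of that gating, with the static checks collapsed into a single placeholder flag:

```python
# Sketch of the batch-size-gated fusion decision added to GptOssDecoderLayer.
# static_conditions_ok stands in for the real checks (TP world size, SM100
# support, FlashInfer availability, server flag); it is a placeholder, not an
# sglang API.
def build_fuse_allreduce_lookup_table(
    layer_id: int, num_hidden_layers: int, static_conditions_ok: bool
) -> dict[int, bool]:
    if not static_conditions_ok:
        return {}
    is_last_layer = layer_id == num_hidden_layers - 1
    # Fuse only for small batches (1..128) and never on the last layer,
    # mirroring _build_fuse_allreduce_lookup_table in the diff above.
    return {bs: (0 < bs <= 128 and not is_last_layer) for bs in range(129)}


def should_fuse(table: dict[int, bool], batch_size: int) -> bool:
    if batch_size > 128:
        return False
    return table.get(batch_size, False)


table = build_fuse_allreduce_lookup_table(
    layer_id=0, num_hidden_layers=24, static_conditions_ok=True
)
print(should_fuse(table, 64), should_fuse(table, 256))  # True False
```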
sglang/srt/models/granite.py  CHANGED
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name: