sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects each version as it appears in its public registry.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4v_moe.py
CHANGED
@@ -6,13 +6,10 @@ import torch
 import torch.nn as nn
 from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 
-from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,7 +17,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import add_prefix, is_cuda
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
 _is_cuda = is_cuda()
@@ -39,12 +36,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
     ) -> None:
         nn.Module.__init__(self)
 
-        config.moe_layer_freq = 1
         self.config = config
         vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
             0
             if get_global_server_args().disable_shared_experts_fusion
@@ -58,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         )
         self.visual = Glm4vVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
@@ -77,38 +71,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            get_global_server_args().disable_shared_experts_fusion = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
-
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -130,117 +93,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                    or self.quant_config.get_name() == "compressed_tensors"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
 
-            for moe_layer in moe_layers:
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    # online fp8 quantization does not load weight_scale
-                    if shared_expert_weight_name not in weights_dict:
-                        continue
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
+            num_experts=self.config.n_routed_experts,
        )
 
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
             nextn_spec_weight_names = [
@@ -300,23 +160,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                 # name will be updated to mlp.experts[0].gate_up_proj, which
                 # will then be updated below in expert_params_mapping
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if
+                if "mlp.experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-
+                if name not in params_dict:
+                    continue
 
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -328,64 +201,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                     if "visual" in name:
-                        # adapt to VisionAttention
+                        # adapt to VisionAttention for GLM-V
                         name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
 
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    if
-
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
+                    if name not in params_dict:
+                        continue
 
-
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                        else:
-                            if (
-                                "k_scale" in name or "v_scale" in name
-                            ) and name not in params_dict:
-                                # modelopt attn kv scale is named differently
-                                if any(scale in name for scale in ["k_scale", "v_scale"]):
-                                    name = name.replace("_proj", "attn_mqa")
-                                else:
-                                    logger.warning(
-                                        f"Unknown scale found in checkpoint: {name}"
-                                    )
+                    if name in params_dict.keys():
                         param = params_dict[name]
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
@@ -395,6 +225,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                             self.config, name, loaded_weight
                         )
                         weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = [Glm4vMoeForConditionalGeneration]
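The largest behavioral change in this file is in load_weights: the loader now records whether a checkpoint tensor matched any expert mapping (is_expert_weight) and, if the mapped parameter does not exist on the current rank, skips the tensor cleanly instead of falling through to the generic loader. Below is a minimal, self-contained sketch of that for/else skip pattern; the dictionaries and names (params_dict, expert_params_mapping, checkpoint_names) are illustrative stand-ins, not sglang's real data structures.

# Sketch of the expert-weight early-skip pattern introduced in this diff.
# All names below are illustrative stand-ins, not sglang's real structures.
params_dict = {"layers.0.experts.w13_weight": object()}  # parameters on this rank

# (param_name, weight_name, expert_id, shard_id), in the spirit of
# FusedMoE.make_expert_params_mapping(...).
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.gate_proj", 0, "w1"),
    ("experts.w2_weight", "experts.0.down_proj", 0, "w2"),
]

checkpoint_names = [
    "layers.0.experts.0.gate_proj",  # expert weight owned by this rank
    "layers.0.experts.0.down_proj",  # expert weight owned by another rank
    "layers.0.embed_tokens.weight",  # ordinary, non-expert weight
]

for name in checkpoint_names:
    is_expert_weight = False  # track matches to enable early skipping
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # Mark as expert weight even if this rank cannot load it.
        is_expert_weight = True
        mapped = name.replace(weight_name, param_name)
        if mapped not in params_dict:
            continue  # the expert lives on a different rank
        print(f"load {name} -> {mapped} (expert {expert_id}, shard {shard_id})")
        break
    else:
        if is_expert_weight:
            # Expert weight that no local parameter claims: skip silently
            # instead of treating it as an unknown parameter.
            continue
        print(f"generic load for {name}")

The for/else mirrors the shape of the new code: the else branch only runs when no break fired, and is_expert_weight distinguishes "expert weight owned by another rank" from "genuinely unknown parameter", two cases the old loop could not tell apart under expert parallelism.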
sglang/srt/models/gpt_oss.py
CHANGED
@@ -70,18 +70,9 @@ from sglang.srt.models.utils import (
     enable_fused_set_kv_buffer,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import (
-    LazyValue,
-    add_prefix,
-    is_cuda,
-    is_flashinfer_available,
-    is_sm100_supported,
-    make_layers,
-)
+from sglang.srt.utils import LazyValue, add_prefix, is_cuda, make_layers
 
 _is_cuda = is_cuda()
-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 if _is_cuda:
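gpt_oss.py keeps sglang's usual pattern of probing the platform once at import time and branching on module-level flags (_is_cuda here); this diff simply drops the two flags the file no longer reads. A minimal sketch of that pattern follows, with invented names (_probe_fast_kernel, _has_fast_kernel, add) standing in for probes like is_cuda() or is_sm100_supported():

import torch

def _probe_fast_kernel() -> bool:
    # Hypothetical stand-in for probes such as is_cuda() / is_sm100_supported().
    return torch.cuda.is_available()

# Evaluated exactly once, at module import time.
_has_fast_kernel = _probe_fast_kernel()

def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    if _has_fast_kernel:
        return a.add(b)  # stand-in for a device-specific fused kernel
    return a + b  # portable fallback

Probing once at import keeps the per-call cost at a single boolean check, which is also why unused flags like _is_flashinfer_available are worth deleting outright rather than leaving in place.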
|