sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4v_moe.py

@@ -6,13 +6,10 @@ import torch
 import torch.nn as nn
 from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 
-from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,7 +17,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.utils import add_prefix, is_cuda
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
 _is_cuda = is_cuda()
@@ -39,12 +36,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
     ) -> None:
         nn.Module.__init__(self)
 
-        config.moe_layer_freq = 1
         self.config = config
         vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
             0
             if get_global_server_args().disable_shared_experts_fusion
@@ -58,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         )
         self.visual = Glm4vVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
@@ -77,38 +71,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            get_global_server_args().disable_shared_experts_fusion = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
-
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -130,117 +93,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                    or self.quant_config.get_name() == "compressed_tensors"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
 
-            for moe_layer in moe_layers:
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    # online fp8 quantization does not load weight_scale
-                    if shared_expert_weight_name not in weights_dict:
-                        continue
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+            num_experts=self.config.n_routed_experts,
         )
 
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
             nextn_spec_weight_names = [
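
`FusedMoE.make_expert_params_mapping` expands per-expert checkpoint names into `(param_name, weight_name, expert_id, shard_id)` tuples, which the loop in the next hunk unpacks; with shared-expert fusion removed, `num_experts` now covers only the routed experts. A minimal sketch of the shape such a mapping takes — the fused `experts.w13_`/`experts.w2_` parameter names are assumptions borrowed from common fused-MoE conventions, not verified against sglang's implementation:

```python
from typing import List, Tuple

def make_expert_params_mapping_sketch(
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
) -> List[Tuple[str, str, int, str]]:
    # One (param_name, weight_name, expert_id, shard_id) tuple per expert and
    # per projection: gate ("w1") and up ("w3") fold into one fused parameter,
    # down ("w2") stays separate.
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name in (
            ("w1", ckpt_gate_proj_name),
            ("w2", ckpt_down_proj_name),
            ("w3", ckpt_up_proj_name),
        ):
            fused = "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_"
            mapping.append(
                (fused, f"experts.{expert_id}.{ckpt_name}.", expert_id, shard_id)
            )
    return mapping

# make_expert_params_mapping_sketch("gate_proj", "down_proj", "up_proj", 2)[0]
# -> ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
```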
@@ -300,23 +160,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                 # name will be updated to mlp.experts[0].gate_up_proj, which
                 # will then be updated below in expert_params_mapping
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
+                if "mlp.experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
+                if name not in params_dict:
+                    continue
 
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -328,64 +201,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                     if "visual" in name:
-                        # adapt to VisionAttention
+                        # adapt to VisionAttention for GLM-V
                         name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
 
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
+                    if name not in params_dict:
+                        continue
 
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        if (
-                            "k_scale" in name or "v_scale" in name
-                        ) and name not in params_dict:
-                            # modelopt attn kv scale is named differently
-                            if any(scale in name for scale in ["k_scale", "v_scale"]):
-                                name = name.replace("_proj", "attn_mqa")
-                            else:
-                                logger.warning(
-                                    f"Unknown scale found in checkpoint: {name}"
-                                )
+                    if name in params_dict.keys():
                         param = params_dict[name]
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
@@ -395,6 +225,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                                 self.config, name, loaded_weight
                             )
                         weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = [Glm4vMoeForConditionalGeneration]
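
The expert-loading changes above hinge on Python's `for ... else`: the `else` block runs only when the loop finishes without `break`, so the `is_expert_weight` flag lets that block tell "expert weight owned by another rank" apart from "weight that matched nothing". A standalone sketch of the same control flow, with hypothetical names (`FakeParam`, the mapping contents) standing in for sglang's real objects:

```python
class FakeParam:
    """Stand-in for a model parameter carrying a weight_loader hook."""

    def weight_loader(self, param, tensor, name, shard_id=None, expert_id=None):
        print(f"loaded {name} (expert={expert_id}, shard={shard_id})")


def load_weight(name, tensor, params_dict, expert_params_mapping):
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # The checkpoint name matches an expert weight, even if this rank
        # does not hold that expert.
        is_expert_weight = True
        mapped = name.replace(weight_name, param_name)
        if mapped not in params_dict:
            continue  # expert lives on another EP rank; keep scanning
        param = params_dict[mapped]
        param.weight_loader(
            param, tensor, mapped, shard_id=shard_id, expert_id=expert_id
        )
        break
    else:
        # The loop ended without `break`, so nothing was loaded.
        if is_expert_weight:
            return  # foreign expert: skip silently instead of warning
        print(f"Parameter {name} not found in params_dict")


params = {"experts_rank0.w2_weight": FakeParam()}
mapping = [
    ("experts_rank0.w2_", "experts.0.down_proj.", 0, "w2"),
    ("experts_rank1.w2_", "experts.7.down_proj.", 7, "w2"),
]
load_weight("experts.0.down_proj.weight", None, params, mapping)  # loads
load_weight("experts.7.down_proj.weight", None, params, mapping)  # silent skip
load_weight("lm_head.weight", None, params, mapping)              # warns
```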
@@ -70,18 +70,9 @@ from sglang.srt.models.utils import (
     enable_fused_set_kv_buffer,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import (
-    LazyValue,
-    add_prefix,
-    is_cuda,
-    is_flashinfer_available,
-    is_sm100_supported,
-    make_layers,
-)
+from sglang.srt.utils import LazyValue, add_prefix, is_cuda, make_layers
 
 _is_cuda = is_cuda()
-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 if _is_cuda:
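
The dropped `_is_flashinfer_available` and `_is_sm100_supported` module-level flags were hardware/library capability probes evaluated once at import time. A sketch of what such a compute-capability probe typically looks like — an illustration, not sglang's exact `is_sm100_supported`:

```python
import torch

def is_sm100_supported_sketch(device=None) -> bool:
    # SM 10.0 is the Blackwell-generation compute capability; a module-level
    # flag like the removed _is_sm100_supported would cache this check.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) == (10, 0)
```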