sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
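Since sglang/version.py is bumped (entry 144 above), the simplest post-upgrade sanity check is to read the version string back from Python. A minimal sketch, assuming the top-level `__version__` re-export that current sglang releases provide:

    import sglang

    # Should report the new post-release tag after the wheel is upgraded.
    print(sglang.__version__)
    assert sglang.__version__ == "0.5.4.post2", sglang.__version__

The rendered hunks below cover three of the model files: sglang/srt/models/glm4v.py, glm4v_moe.py, and gpt_oss.py.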
sglang/srt/models/glm4v.py
@@ -1,15 +1,35 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Modeling from:
+# ./llama.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
+"""Inference-only GLM-4.1V model compatible with HuggingFace weights."""
+
 import logging
-from functools import lru_cache, partial
+from functools import lru_cache
 from typing import Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
-from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
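The new `einops` import serves the layout shuffles in `Glm4vVisionBlock.forward` further down, where activations move between `[seq, batch, hidden]` and `[batch, seq, hidden]`. A quick sketch of what that pattern does (shapes illustrative):

    import torch
    from einops import rearrange

    x = torch.randn(7, 2, 64)  # [seq, batch, hidden]
    y = rearrange(x, "s b h -> b s h")
    # The pattern is a named axis permutation, i.e. a transpose of the first two axes:
    assert torch.equal(y, x.transpose(0, 1))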
@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.schedule_batch import MultimodalDataItem
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4 import Glm4Model
-from sglang.srt.models.qwen2_5_vl import (
-    Qwen2_5_VisionBlock,
-    Qwen2_5_VLForConditionalGeneration,
-)
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
-            output_sizes=[hidden_features] * 2,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
         return x
 
 
-class Glm4vVisionBlock(Qwen2_5_VisionBlock):
+class Glm4vVisionBlock(nn.Module):
     def __init__(
         self,
-        config: Glm4vVisionConfig,
-        norm_layer: Optional[nn.Module] = None,
+        dim: int,
+        intermediate_dim: int,
+        num_heads: int,
+        attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-5,
     ) -> None:
-        super().__init__(
-            dim=config.hidden_size,
-            intermediate_dim=config.out_hidden_size,
-            num_heads=config.num_heads,
-            hidden_act=config.hidden_act,
-            norm_layer=norm_layer,
+        super().__init__()
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
+
+        if attn_implementation is None:
+            softmax_in_single_precision = False
+            qkv_backend = None
+            flatten_batch = True
+        elif attn_implementation == "sdpa":
+            softmax_in_single_precision = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_2":
+            softmax_in_single_precision = False
+            qkv_backend = "triton_attn"
+            flatten_batch = True
+        elif attn_implementation == "eager":
+            softmax_in_single_precision = True
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
+            flatten_batch = True
+
+        self.attn = VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=flatten_batch,
             quant_config=quant_config,
-            prefix=prefix,
-            num_dummy_heads=config.num_dummy_heads,
-            rms_norm_eps=config.rms_norm_eps,
+            prefix=add_prefix("attn", prefix),
+            num_dummy_heads=num_dummy_heads,
         )
-
         self.mlp = Glm4vVisionMLP(
-            config.hidden_size,
-            config.out_hidden_size,
-            bias=False,
+            dim,
+            intermediate_dim,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
 
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
+        attn = self.attn(
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
+        attn = rearrange(attn, "b s h -> s b h")
+
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
+        return x
+
 
 class Glm4vVisionPatchEmbed(nn.Module):
     def __init__(
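The `attn_implementation` switch maps HF-style backend names onto sglang's `VisionAttention` options (sdpa, triton_attn, fa3). Unrolled, the rewritten block is a conventional pre-norm transformer layer; the only non-obvious step is the fused add-norm, where `self.norm2(x2d, residual=attn2d)` is assumed to return both the normalized sum and the sum itself so the residual does not have to be recomputed. A plain-PyTorch sketch of the same dataflow with the fusion expanded (the callables stand in for the real modules):

    import torch

    def vision_block_reference(x, norm1, attn_fn, norm2, mlp_fn):
        # x: [S, B, H]
        attn = attn_fn(norm1(x))    # pre-norm, then attention
        x = x + attn                # first residual (what the fused norm2 also returns)
        x = x + mlp_fn(norm2(x))    # pre-norm MLP, second residual
        return x

    # Identity stand-ins, just to show the call shape:
    out = vision_block_reference(torch.randn(7, 2, 64), *([lambda t: t] * 4))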
@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
     def __init__(
         self,
         vision_config: Glm4vVisionConfig,
-        norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
             hidden_size=self.hidden_size,
         )
 
-        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
 
         self.blocks = nn.ModuleList(
             [
                 Glm4vVisionBlock(
-                    config=vision_config,
-                    norm_layer=norm_layer,
+                    dim=self.hidden_size,
+                    intermediate_dim=self.out_hidden_size,
+                    num_heads=self.num_heads,
                     quant_config=quant_config,
                     prefix=add_prefix(f"blocks.{layer_idx}", prefix),
+                    rms_norm_eps=vision_config.rms_norm_eps,
                 )
                 for layer_idx in range(depth)
             ]
@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
         return x
 
 
-class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+class Glm4vForConditionalGeneration(nn.Module):
     def __init__(
         self,
         config: Glm4vConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        nn.Module.__init__(self)
+        super().__init__()
 
         self.config = config
-        vision_utils.update_vit_attn_dummy_heads_config(self.config)
-        self.model = Glm4Model(
-            config,
-            quant_config,
-            prefix=add_prefix("model", prefix),
-        )
         self.visual = Glm4vVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
 
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
+
+        self.model = Glm4Model(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
             prefix=add_prefix("lm_head", prefix),
         )
 
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0
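`is_mrope_enabled` is now set once in `__init__`, and detection is a plain key check on the HF `rope_scaling` dict; when it is true, the forward pass (next hunk) expects `(3, seq_len)` positions rather than `(seq_len,)`. A worked example with an illustrative config fragment (the section sizes are made up):

    # Illustrative rope_scaling dict; only the key presence matters here.
    rope_scaling = {"rope_type": "default", "mrope_section": [8, 12, 12]}

    is_mrope_enabled = "mrope_section" in rope_scaling
    assert is_mrope_enabled  # -> forward consumes (3, seq_len) mrope positions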
@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         video_embeds = torch.split(video_embeds, split_sizes)
         return torch.cat(video_embeds)
 
-    def _update_hf_config(self):
-        """update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
-        tp_size = get_attention_tp_size()
-        num_heads = self.config.vision_config.num_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
 
-        if num_heads % tp_size != 0:
-            num_dummy_heads = (
-                (num_heads + tp_size - 1) // tp_size
-            ) * tp_size - num_heads
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        """Run forward pass for GLM-4.1V.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for GLM-4.1V
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+                (Use input_metadata.mrope_positions to replace it)
+        """
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
+            if self.is_mrope_enabled:
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.model,
+            multimodal_model=self,
+            positions=positions,
+        )
 
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
 
     def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
         """pad attn qkv weights for dummy heads"""
@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if "language_model." in name:
-                name = name.replace("language_model.", "")
-            if "model.visual." in name:
-                name = name.replace("model.visual.", "visual.")
-
             if "rotary_emb.inv_freq" in name:
                 continue
+            if "language_model" in name:
+                name = name.replace(r"model.language_model.", r"model.")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
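Two behavioral notes on the reordered loop: `rotary_emb.inv_freq` entries are now skipped before any renaming, and the language-model rename is anchored to the full `model.language_model.` prefix instead of stripping every bare `language_model.` substring. A sketch of the new mapping on illustrative checkpoint keys:

    def remap(name: str) -> str:
        # Mirrors the renames in load_weights (plain str.replace on literals).
        if "language_model" in name:
            name = name.replace("model.language_model.", "model.")
        if "model.visual." in name:
            name = name.replace("model.visual.", "visual.")
        return name

    assert remap("model.language_model.layers.0.self_attn.q_proj.weight") == (
        "model.layers.0.self_attn.q_proj.weight"
    )
    assert remap("model.visual.blocks.0.attn.qkv_proj.weight") == (
        "visual.blocks.0.attn.qkv_proj.weight"
    )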
@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
             )
             weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            del self.lm_head.weight
+            self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 EntryClass = [Glm4vForConditionalGeneration]
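The new `set_embed_and_head` deletes each existing weight before rebinding. That ordering matters in PyTorch: `nn.Module.__setattr__` refuses to overwrite a registered `Parameter` with a plain tensor, so the attribute has to be deregistered first. A minimal sketch of the behavior this works around:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(4, 8)
    new_weight = torch.zeros(4, 8)
    try:
        emb.weight = new_weight   # rejected: 'weight' is a registered Parameter
    except TypeError:
        pass
    del emb.weight                # deregister the Parameter
    emb.weight = new_weight       # now binds as an ordinary attribute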
sglang/srt/models/glm4v_moe.py
@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         )
         self.visual = Glm4vVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
sglang/srt/models/gpt_oss.py
@@ -70,18 +70,9 @@ from sglang.srt.models.utils import (
     enable_fused_set_kv_buffer,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import (
-    LazyValue,
-    add_prefix,
-    is_cuda,
-    is_flashinfer_available,
-    is_sm100_supported,
-    make_layers,
-)
+from sglang.srt.utils import LazyValue, add_prefix, is_cuda, make_layers
 
 _is_cuda = is_cuda()
-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 if _is_cuda: