sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/models/glm4v.py
@@ -0,0 +1,589 @@
+import logging
+from functools import lru_cache, partial
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.glm4 import Glm4Model
+from sglang.srt.models.qwen2_5_vl import (
+    Qwen2_5_VisionBlock,
+    Qwen2_5_VLForConditionalGeneration,
+)
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Glm4vRMSNorm(RMSNorm):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        original_shape = x.shape
+        x_2d = x.contiguous().reshape(-1, original_shape[-1])
+        x_2d = super().forward(x_2d)
+        x = x_2d.reshape(original_shape)
+        return x
+
+
+class Glm4vVisionMLP(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Glm4vVisionBlock(Qwen2_5_VisionBlock):
+    def __init__(
+        self,
+        config: Glm4vVisionConfig,
+        norm_layer: Optional[nn.Module] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            dim=config.hidden_size,
+            intermediate_dim=config.out_hidden_size,
+            num_heads=config.num_heads,
+            hidden_act=config.hidden_act,
+            norm_layer=norm_layer,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.mlp = Glm4vVisionMLP(
+            config.hidden_size,
+            config.out_hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+
+class Glm4vVisionPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_channels: int = 3,
+        hidden_size: int = 1536,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.hidden_size = hidden_size
+        self.in_channels = in_channels
+
+        kernel_size = (temporal_patch_size, patch_size, patch_size)
+        self.proj = nn.Conv3d(
+            in_channels,
+            hidden_size,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=True,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.view(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )
+        x = self.proj(x).view(-1, self.hidden_size)
+        return x
+
+
+class Glm4vPatchMerger(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        context_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = d_model
+        self.proj = ColumnParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("proj", prefix),
+            gather_output=True,
+        )
+        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_sizes=[context_dim] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            context_dim,
+            self.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.extra_activation_func = nn.GELU()
+
+    def forward(self, x: torch.Tensor):
+        x, _ = self.proj(x)
+        x = self.extra_activation_func(self.post_projection_norm(x))
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = F.silu(gate) * up
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Glm4vVisionEmbeddings(nn.Module):
+    def __init__(self, config: Glm4vVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(
+        self, embeddings, lengths, image_shapes, h_coords, w_coords
+    ) -> torch.Tensor:
+        pos_embed_weight = self.position_embedding.weight
+        hidden_size = pos_embed_weight.shape[1]
+        total_seq = h_coords.shape[0]
+        device = pos_embed_weight.device
+
+        # Move coordinates to correct device
+        h_coords, w_coords = h_coords.to(device), w_coords.to(device)
+
+        # Handle empty sequence case
+        if total_seq == 0:
+            adapted_pos_embed = torch.empty(
+                0, hidden_size, device=device, dtype=pos_embed_weight.dtype
+            )
+        else:
+            # Convert inputs to tensors if needed
+            if isinstance(lengths, list):
+                lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+            if not isinstance(image_shapes, torch.Tensor):
+                image_shapes = torch.tensor(
+                    image_shapes, device=device, dtype=torch.long
+                )
+
+            # Prepare 2D position embedding
+            orig_size_sq = pos_embed_weight.shape[0]
+            orig_size = int(orig_size_sq**0.5)
+            pos_embed_2d = (
+                pos_embed_weight.view(orig_size, orig_size, hidden_size)
+                .permute(2, 0, 1)
+                .unsqueeze(0)
+                .to(device=device, dtype=torch.float32)
+            )
+
+            # Calculate target dimensions for each patch
+            target_h = torch.cat(
+                [image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]
+            ).to(device=device, dtype=torch.float32)
+            target_w = torch.cat(
+                [image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]
+            ).to(device=device, dtype=torch.float32)
+
+            # Normalize coordinates to [-1, 1] range for grid_sample
+            h_coords = h_coords.to(device=device, dtype=torch.float32)
+            w_coords = w_coords.to(device=device, dtype=torch.float32)
+            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+            # Create sampling grid
+            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+
+            # Perform bicubic interpolation
+            interpolated_embed_fp32 = F.grid_sample(
+                pos_embed_2d,
+                grid,
+                mode="bicubic",
+                align_corners=False,
+                padding_mode="border",
+            )
+
+            # Reshape and convert back to original dtype
+            adapted_pos_embed_fp32 = (
+                interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+            )
+            adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(
+                embeddings.device
+            )
+
+        # Add adapted position encoding to embeddings
+        embeddings = embeddings + adapted_pos_embed
+        return embeddings
+
+
+class Glm4vVisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta
+                ** (
+                    torch.arange(
+                        0,
+                        self.dim,
+                        2,
+                        dtype=torch.float,
+                        device=self.inv_freq.device,
+                    )
+                    / self.dim
+                )
+            )
+            seq = torch.arange(
+                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+            )
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
+
+
+class Glm4vVisionModel(nn.Module):
+    def __init__(
+        self,
+        vision_config: Glm4vVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        patch_size = vision_config.patch_size
+        temporal_patch_size = vision_config.temporal_patch_size
+        in_channels = vision_config.in_channels
+        depth = vision_config.depth
+        self.hidden_size = vision_config.hidden_size
+        self.num_heads = vision_config.num_heads
+
+        self.patch_size = vision_config.patch_size
+        self.spatial_merge_size = vision_config.spatial_merge_size
+        self.out_hidden_size = vision_config.out_hidden_size
+
+        self.patch_embed = Glm4vVisionPatchEmbed(
+            patch_size=patch_size,
+            temporal_patch_size=temporal_patch_size,
+            in_channels=in_channels,
+            hidden_size=self.hidden_size,
+        )
+
+        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = nn.ModuleList(
+            [
+                Glm4vVisionBlock(
+                    config=vision_config,
+                    norm_layer=norm_layer,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
+                )
+                for layer_idx in range(depth)
+            ]
+        )
+
+        self.merger = Glm4vPatchMerger(
+            d_model=vision_config.out_hidden_size,
+            context_dim=vision_config.intermediate_size,
+            quant_config=quant_config,
+            bias=False,
+            prefix=add_prefix("merger", prefix),
+        )
+
+        self.embeddings = Glm4vVisionEmbeddings(vision_config)
+
+        self.post_conv_layernorm = Glm4vRMSNorm(
+            vision_config.hidden_size, eps=vision_config.rms_norm_eps
+        )
+        self.downsample = nn.Conv2d(
+            in_channels=vision_config.hidden_size,
+            out_channels=vision_config.out_hidden_size,
+            kernel_size=vision_config.spatial_merge_size,
+            stride=vision_config.spatial_merge_size,
+        )
+        self.post_layernorm = Glm4vRMSNorm(
+            vision_config.hidden_size, eps=vision_config.rms_norm_eps
+        )
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.patch_embed.proj.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.patch_embed.proj.weight.device
+
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = (
+                hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            wpos_ids = (
+                wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb, pos_ids
+
+    def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+        x = self.post_conv_layernorm(x)
+
+        # compute position embedding
+        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        x = self.embeddings(
+            x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
+        )
+
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        rotary_pos_emb_tuple = (emb.cos(), emb.sin())
+
+        # x.shape: (s, b, d) where b=1 for vision processing
+        # transformers
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)
+
+        # adapter
+        x = self.post_layernorm(x)
+        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
+        x = x.permute(0, 3, 1, 2)
+        x = self.downsample(x).view(-1, self.out_hidden_size)
+        x = self.merger(x)
+
+        return x
+
+
+class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+    def __init__(
+        self,
+        config: Glm4vConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+
+        self.config = config
+
+        self.model = Glm4Model(
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+        self.visual = Glm4vVisionModel(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.cat(
+            [item.feature.squeeze(0) for item in items], dim=0
+        ).type(self.visual.dtype)
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+        # For multi-image, pixel_values is [num_of_images, L, C] shape
+        # assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        split_sizes = (
+            image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
+        ).tolist()
+        image_embeds = torch.split(image_embeds, split_sizes)
+        return torch.cat(image_embeds)
+
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values_videos = torch.cat(
+            [item.feature.squeeze(0) for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        # For multi-video, pixel_values_videos is [num_of_videos, L, C] shape
+        # assert pixel_values_videos.dim() == 2, pixel_values_videos.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+
+        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
+        temp_frames_hw = []
+        for t, h, w in video_grid_thw:
+            repeated_row = (
+                torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
+            )
+            temp_frames_hw.append(repeated_row)
+        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
+        video_embeds = self.visual(
+            pixel_values_videos, grid_thw=flattened_video_grid_thw
+        )
+        split_sizes = (
+            video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
+        ).tolist()
+        video_embeds = torch.split(video_embeds, split_sizes)
+        return torch.cat(video_embeds)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".up_proj", 1),
+            (".gate_up_proj", ".gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "language_model." in name:
+                name = name.replace("language_model.", "")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "visual" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4vForConditionalGeneration]
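
The most intricate part of the new file is Glm4vVisionEmbeddings.forward, which treats the learned position-embedding table as a square 2-D image and bicubically resamples it at each patch's fractional coordinates via F.grid_sample, so a single pretrained table serves arbitrary image grids. The following is a standalone toy sketch of just that resampling step (not part of the diff; all sizes are illustrative):

import torch
import torch.nn.functional as F

S, D = 24, 16                     # pretrained grid side and embed dim (toy sizes)
table = torch.randn(S * S, D)     # stands in for nn.Embedding.weight
H, W = 10, 30                     # target patch grid for one image

h = torch.arange(H).repeat_interleave(W).float()  # per-patch row index
w = torch.arange(W).repeat(H).float()             # per-patch col index
norm_h = ((h + 0.5) / H) * 2 - 1                  # map to [-1, 1] for grid_sample
norm_w = ((w + 0.5) / W) * 2 - 1

grid = torch.stack((norm_w, norm_h), dim=-1)[None, :, None, :]  # (1, H*W, 1, 2)
img = table.view(S, S, D).permute(2, 0, 1)[None]                # (1, D, S, S)
resampled = F.grid_sample(
    img, grid, mode="bicubic", align_corners=False, padding_mode="border"
)                                                               # (1, D, H*W, 1)
pos_embed = resampled.squeeze(0).squeeze(-1).T                  # (H*W, D)
assert pos_embed.shape == (H * W, D)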
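Since the file ends by registering Glm4vForConditionalGeneration in EntryClass, the model should load through sglang's normal entry points once this wheel is installed. A minimal smoke-test sketch, assuming the offline sglang.Engine API and a hypothetical GLM-4.1V checkpoint id:

import sglang as sgl

if __name__ == "__main__":
    # Checkpoint id is an assumption; any config that maps to
    # Glm4vForConditionalGeneration should route to the new model file.
    llm = sgl.Engine(model_path="THUDM/GLM-4.1V-9B-Thinking")
    out = llm.generate(
        prompt="Describe the image in one sentence.",
        sampling_params={"temperature": 0.0, "max_new_tokens": 64},
        # image_data takes a URL or local path; the exact image placeholder
        # token in the prompt depends on the model's chat template.
        image_data="https://example.com/cat.png",
    )
    print(out["text"])
    llm.shutdown()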