sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe.py
@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
         return x
 
 
@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         if self.ep_size > 1:
             if (
                 self.tp_size > 1
-                and not can_fuse_mlp_allreduce
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             final_hidden_states += shared_output
             if (
                 self.tp_size > 1
-                and not can_fuse_mlp_allreduce
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         # fused in biased_grouped_topk so we can skip here
        final_hidden_states *= self.routed_scaling_factor
         if self.ep_size > 1:
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         else:
             if shared_output is not None:
                 final_hidden_states += shared_output
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
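
Across these hunks the flag can_fuse_mlp_allreduce is renamed to should_allreduce_fusion in Glm4MoeMLP and Glm4MoeSparseMoeBlock; the gating logic is unchanged: when the flag is set, the block skips its own tensor-parallel all-reduce so that a downstream fused kernel can perform the reduction instead. A minimal sketch of that pattern, with toy_all_reduce as a hypothetical stand-in for tensor_model_parallel_all_reduce (not sglang's implementation):

import torch

def toy_all_reduce(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: a real implementation sums the partial
    # results held by each tensor-parallel rank.
    return x

def mlp_forward(
    x: torch.Tensor, tp_size: int, should_allreduce_fusion: bool = False
) -> torch.Tensor:
    partial = x * 2.0  # placeholder for the row-parallel down_proj partial sum
    if tp_size > 1 and not should_allreduce_fusion:
        # Unfused path: this block performs its own reduction.
        return toy_all_reduce(partial)
    # Fused path: return the partial result and let a later fused kernel
    # (e.g. allreduce + residual add + rmsnorm) reduce it.
    return partial

x = torch.ones(4, 8)
assert torch.equal(mlp_forward(x, tp_size=2, should_allreduce_fusion=True), x * 2.0)

The same skip is threaded into the row-parallel projection via down_proj(x, skip_all_reduce=...) in the first hunk above.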

sglang/srt/models/glm4v.py (new file)
@@ -0,0 +1,589 @@
+import logging
+from functools import lru_cache, partial
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.glm4 import Glm4Model
+from sglang.srt.models.qwen2_5_vl import (
+    Qwen2_5_VisionBlock,
+    Qwen2_5_VLForConditionalGeneration,
+)
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Glm4vRMSNorm(RMSNorm):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        original_shape = x.shape
+        x_2d = x.contiguous().reshape(-1, original_shape[-1])
+        x_2d = super().forward(x_2d)
+        x = x_2d.reshape(original_shape)
+        return x
+
+
+class Glm4vVisionMLP(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Glm4vVisionBlock(Qwen2_5_VisionBlock):
+    def __init__(
+        self,
+        config: Glm4vVisionConfig,
+        norm_layer: Optional[nn.Module] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            dim=config.hidden_size,
+            intermediate_dim=config.out_hidden_size,
+            num_heads=config.num_heads,
+            hidden_act=config.hidden_act,
+            norm_layer=norm_layer,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.mlp = Glm4vVisionMLP(
+            config.hidden_size,
+            config.out_hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+
+class Glm4vVisionPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_channels: int = 3,
+        hidden_size: int = 1536,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.hidden_size = hidden_size
+        self.in_channels = in_channels
+
+        kernel_size = (temporal_patch_size, patch_size, patch_size)
+        self.proj = nn.Conv3d(
+            in_channels,
+            hidden_size,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=True,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.view(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )
+        x = self.proj(x).view(-1, self.hidden_size)
+        return x
+
+
+class Glm4vPatchMerger(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        context_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = d_model
+        self.proj = ColumnParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("proj", prefix),
+            gather_output=True,
+        )
+        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_sizes=[context_dim] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            context_dim,
+            self.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.extra_activation_func = nn.GELU()
+
+    def forward(self, x: torch.Tensor):
+        x, _ = self.proj(x)
+        x = self.extra_activation_func(self.post_projection_norm(x))
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = F.silu(gate) * up
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Glm4vVisionEmbeddings(nn.Module):
+    def __init__(self, config: Glm4vVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(
+        self, embeddings, lengths, image_shapes, h_coords, w_coords
+    ) -> torch.Tensor:
+        pos_embed_weight = self.position_embedding.weight
+        hidden_size = pos_embed_weight.shape[1]
+        total_seq = h_coords.shape[0]
+        device = pos_embed_weight.device
+
+        # Move coordinates to correct device
+        h_coords, w_coords = h_coords.to(device), w_coords.to(device)
+
+        # Handle empty sequence case
+        if total_seq == 0:
+            adapted_pos_embed = torch.empty(
+                0, hidden_size, device=device, dtype=pos_embed_weight.dtype
+            )
+        else:
+            # Convert inputs to tensors if needed
+            if isinstance(lengths, list):
+                lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+            if not isinstance(image_shapes, torch.Tensor):
+                image_shapes = torch.tensor(
+                    image_shapes, device=device, dtype=torch.long
+                )
+
+            # Prepare 2D position embedding
+            orig_size_sq = pos_embed_weight.shape[0]
+            orig_size = int(orig_size_sq**0.5)
+            pos_embed_2d = (
+                pos_embed_weight.view(orig_size, orig_size, hidden_size)
+                .permute(2, 0, 1)
+                .unsqueeze(0)
+                .to(device=device, dtype=torch.float32)
+            )
+
+            # Calculate target dimensions for each patch
+            target_h = torch.cat(
+                [image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]
+            ).to(device=device, dtype=torch.float32)
+            target_w = torch.cat(
+                [image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]
+            ).to(device=device, dtype=torch.float32)
+
+            # Normalize coordinates to [-1, 1] range for grid_sample
+            h_coords = h_coords.to(device=device, dtype=torch.float32)
+            w_coords = w_coords.to(device=device, dtype=torch.float32)
+            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+            # Create sampling grid
+            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+
+            # Perform bicubic interpolation
+            interpolated_embed_fp32 = F.grid_sample(
+                pos_embed_2d,
+                grid,
+                mode="bicubic",
+                align_corners=False,
+                padding_mode="border",
+            )
+
+            # Reshape and convert back to original dtype
+            adapted_pos_embed_fp32 = (
+                interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+            )
+            adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(
+                embeddings.device
+            )
+
+        # Add adapted position encoding to embeddings
+        embeddings = embeddings + adapted_pos_embed
+        return embeddings
+
+
+class Glm4vVisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta
+                ** (
+                    torch.arange(
+                        0,
+                        self.dim,
+                        2,
+                        dtype=torch.float,
+                        device=self.inv_freq.device,
+                    )
+                    / self.dim
+                )
+            )
+            seq = torch.arange(
+                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+            )
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
+
+
+class Glm4vVisionModel(nn.Module):
+    def __init__(
+        self,
+        vision_config: Glm4vVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        patch_size = vision_config.patch_size
+        temporal_patch_size = vision_config.temporal_patch_size
+        in_channels = vision_config.in_channels
+        depth = vision_config.depth
+        self.hidden_size = vision_config.hidden_size
+        self.num_heads = vision_config.num_heads
+
+        self.patch_size = vision_config.patch_size
+        self.spatial_merge_size = vision_config.spatial_merge_size
+        self.out_hidden_size = vision_config.out_hidden_size
+
+        self.patch_embed = Glm4vVisionPatchEmbed(
+            patch_size=patch_size,
+            temporal_patch_size=temporal_patch_size,
+            in_channels=in_channels,
+            hidden_size=self.hidden_size,
+        )
+
+        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = nn.ModuleList(
+            [
+                Glm4vVisionBlock(
+                    config=vision_config,
+                    norm_layer=norm_layer,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
+                )
+                for layer_idx in range(depth)
+            ]
+        )
+
+        self.merger = Glm4vPatchMerger(
+            d_model=vision_config.out_hidden_size,
+            context_dim=vision_config.intermediate_size,
+            quant_config=quant_config,
+            bias=False,
+            prefix=add_prefix("merger", prefix),
+        )
+
+        self.embeddings = Glm4vVisionEmbeddings(vision_config)
+
+        self.post_conv_layernorm = Glm4vRMSNorm(
+            vision_config.hidden_size, eps=vision_config.rms_norm_eps
+        )
+        self.downsample = nn.Conv2d(
+            in_channels=vision_config.hidden_size,
+            out_channels=vision_config.out_hidden_size,
+            kernel_size=vision_config.spatial_merge_size,
+            stride=vision_config.spatial_merge_size,
+        )
+        self.post_layernorm = Glm4vRMSNorm(
+            vision_config.hidden_size, eps=vision_config.rms_norm_eps
+        )
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.patch_embed.proj.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.patch_embed.proj.weight.device
+
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = (
+                hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            wpos_ids = (
+                wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb, pos_ids
+
+    def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+        x = self.post_conv_layernorm(x)
+
+        # compute position embedding
+        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        x = self.embeddings(
+            x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
+        )
+
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        rotary_pos_emb_tuple = (emb.cos(), emb.sin())
+
+        # x.shape: (s, b, d) where b=1 for vision processing
+        # transformers
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)
+
+        # adapter
+        x = self.post_layernorm(x)
+        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
+        x = x.permute(0, 3, 1, 2)
+        x = self.downsample(x).view(-1, self.out_hidden_size)
+        x = self.merger(x)
+
+        return x
+
+
+class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+    def __init__(
+        self,
+        config: Glm4vConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+
+        self.config = config
+
+        self.model = Glm4Model(
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+        self.visual = Glm4vVisionModel(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.cat(
+            [item.feature.squeeze(0) for item in items], dim=0
+        ).type(self.visual.dtype)
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+        # For multi-image, pixel_values is [num_of_images, L, C] shape
+        # assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        split_sizes = (
+            image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
+        ).tolist()
+        image_embeds = torch.split(image_embeds, split_sizes)
+        return torch.cat(image_embeds)
+
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values_videos = torch.cat(
+            [item.feature.squeeze(0) for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        # For multi-video, pixel_values_videos is [num_of_videos, L, C] shape
+        # assert pixel_values_videos.dim() == 2, pixel_values_videos.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+
+        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
+        temp_frames_hw = []
+        for t, h, w in video_grid_thw:
+            repeated_row = (
+                torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
+            )
+            temp_frames_hw.append(repeated_row)
+        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
+        video_embeds = self.visual(
+            pixel_values_videos, grid_thw=flattened_video_grid_thw
+        )
+        split_sizes = (
+            video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
+        ).tolist()
+        video_embeds = torch.split(video_embeds, split_sizes)
+        return torch.cat(video_embeds)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".up_proj", 1),
+            (".gate_up_proj", ".gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "language_model." in name:
+                name = name.replace("language_model.", "")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "visual" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4vForConditionalGeneration]
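
In the new file above, Glm4vVisionEmbeddings.forward adapts a learned square grid of position embeddings to each image's patch grid by resampling it with bicubic F.grid_sample: each patch's (row, col) center is normalized to [-1, 1] and used as a sampling coordinate. A self-contained sketch of that resampling, with illustrative sizes (the dimensions below are assumptions, not values from sglang):

import torch
import torch.nn.functional as F

orig_size, hidden = 24, 16  # pretrained grid side and embedding dim (illustrative)
pos_embed = torch.randn(orig_size * orig_size, hidden)

# One image whose patch grid is 6 rows x 8 cols.
target_h, target_w = 6, 8
h_coords, w_coords = torch.meshgrid(
    torch.arange(target_h), torch.arange(target_w), indexing="ij"
)
h_coords = h_coords.flatten().float()
w_coords = w_coords.flatten().float()

# Reshape the flat embedding table into a (1, hidden, orig_size, orig_size) "image".
pos_embed_2d = (
    pos_embed.view(orig_size, orig_size, hidden).permute(2, 0, 1).unsqueeze(0)
)

# Normalize patch-center coordinates to [-1, 1]; grid_sample expects (x, y) order.
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)

sampled = F.grid_sample(
    pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
)
adapted = sampled.squeeze(0).squeeze(-1).permute(1, 0)  # (num_patches, hidden)
print(adapted.shape)  # torch.Size([48, 16])

Sampling at patch centers, (coord + 0.5) / target, is consistent with align_corners=False, which treats the grid extremes as pixel edges rather than pixel centers, and padding_mode="border" clamps any out-of-range coordinates to the outermost embeddings.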