sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,589 @@
|
|
1
|
+
import logging
|
2
|
+
from functools import lru_cache, partial
|
3
|
+
from typing import Iterable, List, Optional, Tuple
|
4
|
+
|
5
|
+
import torch
|
6
|
+
import torch.nn as nn
|
7
|
+
import torch.nn.functional as F
|
8
|
+
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
|
9
|
+
|
10
|
+
from sglang.srt.hf_transformers_utils import get_processor
|
11
|
+
from sglang.srt.layers.activation import SiluAndMul
|
12
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
13
|
+
from sglang.srt.layers.linear import (
|
14
|
+
ColumnParallelLinear,
|
15
|
+
MergedColumnParallelLinear,
|
16
|
+
RowParallelLinear,
|
17
|
+
)
|
18
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
19
|
+
from sglang.srt.layers.pooler import Pooler, PoolingType
|
20
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
21
|
+
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
22
|
+
from sglang.srt.managers.schedule_batch import MultimodalDataItem
|
23
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
24
|
+
from sglang.srt.models.glm4 import Glm4Model
|
25
|
+
from sglang.srt.models.qwen2_5_vl import (
|
26
|
+
Qwen2_5_VisionBlock,
|
27
|
+
Qwen2_5_VLForConditionalGeneration,
|
28
|
+
)
|
29
|
+
from sglang.srt.utils import add_prefix
|
30
|
+
|
31
|
+
logger = logging.getLogger(name=__name__)

# Memoised wrapper around get_processor, using functools' default LRU size
# (128 entries) spelled out explicitly.
cached_get_processor = lru_cache(maxsize=128)(get_processor)
|
34
|
+
|
35
|
+
|
36
|
+
class Glm4vRMSNorm(RMSNorm):
    """RMSNorm that accepts inputs of any rank.

    The input is flattened to ``(-1, last_dim)`` before delegating to the base
    :class:`RMSNorm`, and the result is reshaped back to the input's shape.
    (Presumably the base implementation expects 2D input — confirm.)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        # Normalise over the last dimension only; all leading dims are folded.
        flat = x.contiguous().reshape(-1, shape[-1])
        normed = super().forward(flat)
        return normed.reshape(shape)
|
43
|
+
|
44
|
+
|
45
|
+
class Glm4vVisionMLP(nn.Module):
    """Gated MLP of a vision block.

    A fused gate/up projection (two ``hidden_features``-wide halves) feeds
    :class:`SiluAndMul`, and the result is projected back to ``in_features``
    by ``down_proj``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Gate and up projections share one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features, hidden_features],
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor):
        # Parallel linear layers return (output, bias); only the output is used.
        fused, _unused_bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _unused_bias = self.down_proj(activated)
        return projected
|
76
|
+
|
77
|
+
|
78
|
+
class Glm4vVisionBlock(Qwen2_5_VisionBlock):
    """Qwen2.5-VL vision block adapted for GLM-4.1V.

    After the parent constructor runs, the two norms are replaced with
    :class:`Glm4vRMSNorm` (eps taken from ``config.rms_norm_eps``), and the MLP
    is replaced with a bias-free :class:`Glm4vVisionMLP` whose hidden width is
    ``config.out_hidden_size``.
    """

    def __init__(
        self,
        config: Glm4vVisionConfig,
        norm_layer: Optional[nn.Module] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        width = config.hidden_size
        mlp_width = config.out_hidden_size
        super().__init__(
            dim=width,
            intermediate_dim=mlp_width,
            num_heads=config.num_heads,
            hidden_act=config.hidden_act,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=prefix,
        )
        # Override the parent's norms and MLP with GLM-4.1V variants.
        self.norm1 = Glm4vRMSNorm(width, eps=config.rms_norm_eps)
        self.norm2 = Glm4vRMSNorm(width, eps=config.rms_norm_eps)

        self.mlp = Glm4vVisionMLP(
            width,
            mlp_width,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
|
105
|
+
|
106
|
+
|
107
|
+
class Glm4vVisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a non-overlapping ``Conv3d``.

    Each input row holds one ``(in_channels, temporal_patch_size, patch_size,
    patch_size)`` patch. The kernel equals the stride and the patch volume, so
    every patch maps to exactly one ``hidden_size`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1536,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        self.in_channels = in_channels

        patch_volume = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=patch_volume,
            stride=patch_volume,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        patches = x.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        # Conv output is (N, hidden, 1, 1, 1); collapse to (N, hidden).
        return self.proj(patches).view(-1, self.hidden_size)
|
140
|
+
|
141
|
+
|
142
|
+
class Glm4vPatchMerger(nn.Module):
    """Post-downsample projection head of the vision tower.

    Computes ``down_proj(silu(gate) * up)``, where ``gate``/``up`` come from a
    fused projection of ``GELU(LayerNorm(proj(x)))``. Input and output width
    are both ``d_model``; ``context_dim`` is the gated MLP's hidden width.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = d_model
        self.hidden_size = width
        # gather_output=True: every rank gets the full projection before the LayerNorm.
        self.proj = ColumnParallelLinear(
            width,
            width,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("proj", prefix),
            gather_output=True,
        )
        self.post_projection_norm = nn.LayerNorm(width)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=width,
            output_sizes=[context_dim, context_dim],
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            context_dim,
            width,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        self.extra_activation_func = nn.GELU()

    def forward(self, x: torch.Tensor):
        projected, _ = self.proj(x)
        normed = self.post_projection_norm(projected)
        hidden = self.extra_activation_func(normed)
        fused, _ = self.gate_up_proj(hidden)
        gate, up = torch.chunk(fused, 2, dim=-1)
        merged = F.silu(gate) * up
        out, _ = self.down_proj(merged)
        return out
|
186
|
+
|
187
|
+
|
188
|
+
class Glm4vVisionEmbeddings(nn.Module):
    """Learned 2D position embeddings resampled to each image's patch grid.

    The module stores a square table of ``(image_size // patch_size) ** 2``
    position vectors. In :meth:`forward` the table is bicubically resampled
    (``F.grid_sample``) at each patch's normalised (row, col) position within
    its own ``(h, w)`` grid, and the result is added to the patch embeddings.
    """

    def __init__(self, config: "Glm4vVisionConfig"):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self, embeddings, lengths, image_shapes, h_coords, w_coords
    ) -> torch.Tensor:
        """Add resampled position embeddings to ``embeddings``.

        Args:
            embeddings: patch embeddings with one row per patch (assumed
                ``(total_patches, hidden_size)`` — TODO confirm).
            lengths: patch count of each sequence (list or 1-D tensor).
            image_shapes: ``(t, h, w)`` patch-grid rows (list or tensor). Either
                one row per entry of ``lengths``, or one row per image when
                ``lengths`` holds one entry per temporal frame. In the latter
                case each row is repeated ``t`` times.
            h_coords, w_coords: per-patch row / column position inside its
                ``(h, w)`` grid.

        Returns:
            ``embeddings + position_embeddings``, cast to the table's dtype on
            ``embeddings``' device.
        """
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        if total_seq == 0:
            adapted_pos_embed = torch.empty(
                0, hidden_size, device=device, dtype=pos_embed_weight.dtype
            )
            return embeddings + adapted_pos_embed

        lengths = torch.as_tensor(lengths, dtype=torch.long, device=device)
        image_shapes = torch.as_tensor(image_shapes, dtype=torch.long, device=device)
        if image_shapes.shape[0] != lengths.numel():
            # lengths is per temporal frame: give each frame its image's row.
            image_shapes = image_shapes.repeat_interleave(image_shapes[:, 0], dim=0)

        # View the learned table as a (1, C, S, S) float32 image for grid_sample.
        orig_size = int(pos_embed_weight.shape[0] ** 0.5)
        pos_embed_2d = (
            pos_embed_weight.view(orig_size, orig_size, hidden_size)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device=device, dtype=torch.float32)
        )

        # Per-patch target grid height/width (vectorised; no per-image loop).
        target_h = image_shapes[:, 1].repeat_interleave(lengths).to(torch.float32)
        target_w = image_shapes[:, 2].repeat_interleave(lengths).to(torch.float32)

        # Map patch centres to [-1, 1], the coordinate range grid_sample expects.
        h_coords = h_coords.to(device=device, dtype=torch.float32)
        w_coords = w_coords.to(device=device, dtype=torch.float32)
        norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
        norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

        # grid_sample takes (x, y) = (col, row); shape (1, N, 1, 2).
        grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
        interpolated_embed_fp32 = F.grid_sample(
            pos_embed_2d,
            grid,
            mode="bicubic",
            align_corners=False,
            padding_mode="border",
        )

        # (1, C, N, 1) -> (N, C), back in the table's dtype on the input's device.
        adapted_pos_embed = (
            interpolated_embed_fp32.squeeze(0)
            .squeeze(-1)
            .permute(1, 0)
            .to(device=embeddings.device, dtype=pos_embed_weight.dtype)
        )
        return embeddings + adapted_pos_embed
|
277
|
+
|
278
|
+
|
279
|
+
class Glm4vVisionRotaryEmbedding(nn.Module):
    """Rotary-embedding frequency table for the vision tower.

    Row ``i`` of the table is ``i * inv_freq``, with
    ``inv_freq[j] = 1 / theta ** (2j / dim)``. The table is cached and
    regrown to twice the requested length whenever a longer one is needed.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table covers at least ``seqlen`` positions.

        ``seqlen`` may be an int or a 0-dim tensor; it is normalised to a Python
        int. Otherwise the in-place ``*=`` below would mutate the caller's
        tensor, and that tensor would be stored as the cache length.
        """
        seqlen = int(seqlen)
        if seqlen > self._seq_len_cached or self._freqs_cached is None:
            # Over-allocate so slightly longer requests reuse the cache.
            seqlen *= 2
            self._seq_len_cached = seqlen
            # Rebuild inv_freq in float32 on the buffer's current device.
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0,
                        self.dim,
                        2,
                        dtype=torch.float,
                        device=self.inv_freq.device,
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, dim // 2)`` frequency table (a cache slice)."""
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]
|
315
|
+
|
316
|
+
|
317
|
+
class Glm4vVisionModel(nn.Module):
    """GLM-4.1V vision tower.

    Forward pipeline:
      1. Conv3d patch embedding, then RMSNorm.
      2. Learned 2D position embeddings resampled per image.
      3. ``depth`` transformer blocks with 2D rotary embeddings.
      4. RMSNorm, then a ``spatial_merge_size``-strided Conv2d downsample.
      5. The patch-merger MLP.
    """

    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.out_hidden_size = vision_config.out_hidden_size

        self.patch_embed = Glm4vVisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        # Row and column positions each index this table (see rot_pos_emb).
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Glm4vVisionBlock(
                    config=vision_config,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                )
                for layer_idx in range(depth)
            ]
        )

        self.merger = Glm4vPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            bias=False,
            prefix=add_prefix("merger", prefix),
        )

        self.embeddings = Glm4vVisionEmbeddings(vision_config)

        self.post_conv_layernorm = Glm4vRMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )
        self.downsample = nn.Conv2d(
            in_channels=vision_config.hidden_size,
            out_channels=vision_config.out_hidden_size,
            kernel_size=vision_config.spatial_merge_size,
            stride=vision_config.spatial_merge_size,
        )
        self.post_layernorm = Glm4vRMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute rotary frequencies and (row, col) ids for every patch.

        Args:
            grid_thw: ``(num_items, 3)`` rows of ``(t, h, w)`` patch counts.

        Returns:
            ``(rotary_pos_emb, pos_ids)``. ``pos_ids`` is ``(num_patches, 2)``
            holding (row, col). Patches are ordered so that each
            ``spatial_merge_size``² window is contiguous, and the ids are
            repeated ``t`` times per item. ``rotary_pos_emb`` concatenates the
            row and column frequencies of each patch.
        """
        merge = self.spatial_merge_size
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Reorder so every merge x merge window is contiguous.
            hpos_ids = (
                hpos_ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # Pass a plain int so the rotary cache never stores or mutates a tensor.
        max_grid_size = int(grid_thw[:, 1:].max())
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """Encode flattened pixel patches.

        Args:
            x: flattened pixel patches, one row per patch (consumed by
                :class:`Glm4vVisionPatchEmbed`).
            grid_thw: ``(num_items, 3)`` rows of ``(t, h, w)`` patch counts.

        Returns:
            ``(num_patches // spatial_merge_size**2, out_hidden_size)`` features.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = self.post_conv_layernorm(x)

        # compute position embedding
        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        # compute cu_seqlens: one attention segment per temporal frame
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        x = self.embeddings(
            x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
        )

        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        rotary_pos_emb_tuple = (emb.cos(), emb.sin())

        # x.shape: (s, b, d) where b=1 for vision processing
        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)

        # adapter: assumes consecutive merge x merge rows form one spatial
        # window (the ordering rot_pos_emb produces) — TODO confirm for inputs.
        x = self.post_layernorm(x)
        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x).view(-1, self.out_hidden_size)
        x = self.merger(x)

        return x
|
460
|
+
|
461
|
+
|
462
|
+
class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """GLM-4.1V: a :class:`Glm4Model` decoder fed by :class:`Glm4vVisionModel`.

    Inherits the multimodal plumbing of Qwen2.5-VL but builds its own
    submodules. ``nn.Module.__init__`` is called directly, which bypasses the
    parent's constructor.
    """

    def __init__(
        self,
        config: Glm4vConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)

        self.config = config

        self.model = Glm4Model(
            config,
            quant_config,
            prefix=add_prefix("model", prefix),
        )
        self.visual = Glm4vVisionModel(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # rope_scaling may be None/absent; treat that as "no M-RoPE".
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        self.is_mrope_enabled = "mrope_section" in rope_scaling

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Run the vision tower over all image items.

        Returns one flat tensor of image embeddings, in ``items`` order.
        """
        pixel_values = torch.cat(
            [item.feature.squeeze(0) for item in items], dim=0
        ).type(self.visual.dtype)
        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
        # The tower already returns one sequence in item order; splitting per
        # image and re-concatenating would only make a redundant copy.
        return self.visual(pixel_values, grid_thw=image_grid_thw)

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Run the vision tower over all video items, frame by frame.

        Returns one flat tensor of video embeddings, in ``items`` order.
        """
        pixel_values_videos = torch.cat(
            [item.feature.squeeze(0) for item in items], dim=0
        ).type(self.visual.dtype)
        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
        assert video_grid_thw.dim() == 2, video_grid_thw.dim()

        # Expand each (t, h, w) row into t rows of (1, h, w): one grid per frame.
        temp_frames_hw = []
        for t, h, w in video_grid_thw:
            repeated_row = (
                torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
            )
            temp_frames_hw.append(repeated_row)
        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
        return self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, mapping HF names to this module's params.

        Separate q/k/v and gate/up checkpoint tensors are loaded as shards of
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Vision
        ``attn.qkv.`` weights are renamed to ``attn.qkv_proj.``.

        Raises:
            KeyError: a checkpoint weight has no matching parameter (the name
                is logged first).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".up_proj", 1),
            (".gate_up_proj", ".gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "language_model." in name:
                name = name.replace("language_model.", "")
            if "model.visual." in name:
                name = name.replace("model.visual.", "visual.")

            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                try:
                    param = params_dict[name]
                except KeyError:
                    logger.error(
                        "Checkpoint weight %r has no matching model parameter.",
                        name,
                    )
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
587
|
+
|
588
|
+
|
589
|
+
# Model class(es) exported by this module; presumably discovered by sglang's
# model registry via this name — TODO confirm.
EntryClass = [Glm4vForConditionalGeneration]
|