sglang-0.5.0rc0-py3-none-any.whl → sglang-0.5.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe.py
CHANGED
@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None,
+    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=
+        x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
         return x


@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         if self.ep_size > 1:
             if (
                 self.tp_size > 1
-                and not
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             final_hidden_states += shared_output
             if (
                 self.tp_size > 1
-                and not
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states,
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if self.ep_size > 1:
-            if self.tp_size > 1 and not
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         else:
             if shared_output is not None:
                 final_hidden_states += shared_output
-            if self.tp_size > 1 and not
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
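The glm4_moe.py hunks above thread a should_allreduce_fusion flag through Glm4MoeMLP.forward and both sparse-MoE forward paths (the parameter it replaces is truncated in this rendering), so the down projection's tensor-parallel all-reduce can be skipped when a later fused all-reduce handles it. Below is a minimal, self-contained sketch of that control flow; ToyTPMLP and its fields are illustrative assumptions, not the sglang implementation.

# Illustrative sketch only (not sglang code): shows the control flow that
# should_allreduce_fusion encodes -- when True, the MLP defers its tensor-
# parallel all-reduce so a later fused residual-add + all-reduce can do it.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyTPMLP(nn.Module):
    def __init__(self, hidden: int, inter: int, tp_size: int = 1):
        super().__init__()
        self.tp_size = tp_size
        self.gate_up = nn.Linear(hidden, 2 * inter, bias=False)
        self.down = nn.Linear(inter, hidden, bias=False)

    def forward(self, x: torch.Tensor, should_allreduce_fusion: bool = False):
        gate, up = self.gate_up(x).chunk(2, dim=-1)
        x = self.down(F.silu(gate) * up)
        # In the real layer this becomes down_proj(..., skip_all_reduce=...);
        # here we only mark where the collective would (not) run.
        if self.tp_size > 1 and not should_allreduce_fusion:
            torch.distributed.all_reduce(x)  # deferred when fusion is enabled
        return x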
sglang/srt/models/glm4v.py
ADDED
@@ -0,0 +1,589 @@
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4 import Glm4Model
from sglang.srt.models.qwen2_5_vl import (
    Qwen2_5_VisionBlock,
    Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)

cached_get_processor = lru_cache(get_processor)


class Glm4vRMSNorm(RMSNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_shape = x.shape
        x_2d = x.contiguous().reshape(-1, original_shape[-1])
        x_2d = super().forward(x_2d)
        x = x_2d.reshape(original_shape)
        return x


class Glm4vVisionMLP(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Glm4vVisionBlock(Qwen2_5_VisionBlock):
    def __init__(
        self,
        config: Glm4vVisionConfig,
        norm_layer: Optional[nn.Module] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            dim=config.hidden_size,
            intermediate_dim=config.out_hidden_size,
            num_heads=config.num_heads,
            hidden_act=config.hidden_act,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.mlp = Glm4vVisionMLP(
            config.hidden_size,
            config.out_hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )


class Glm4vVisionPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1536,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        self.in_channels = in_channels

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        x = self.proj(x).view(-1, self.hidden_size)
        return x


class Glm4vPatchMerger(nn.Module):
    def __init__(
        self,
        d_model: int,
        context_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = d_model
        self.proj = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("proj", prefix),
            gather_output=True,
        )
        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[context_dim] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            context_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        self.extra_activation_func = nn.GELU()

    def forward(self, x: torch.Tensor):
        x, _ = self.proj(x)
        x = self.extra_activation_func(self.post_projection_norm(x))
        gate_up, _ = self.gate_up_proj(x)
        gate, up = gate_up.chunk(2, dim=-1)
        x = F.silu(gate) * up
        x, _ = self.down_proj(x)
        return x


class Glm4vVisionEmbeddings(nn.Module):
    def __init__(self, config: Glm4vVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self, embeddings, lengths, image_shapes, h_coords, w_coords
    ) -> torch.Tensor:
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        # Move coordinates to correct device
        h_coords, w_coords = h_coords.to(device), w_coords.to(device)

        # Handle empty sequence case
        if total_seq == 0:
            adapted_pos_embed = torch.empty(
                0, hidden_size, device=device, dtype=pos_embed_weight.dtype
            )
        else:
            # Convert inputs to tensors if needed
            if isinstance(lengths, list):
                lengths = torch.tensor(lengths, device=device, dtype=torch.long)
            if not isinstance(image_shapes, torch.Tensor):
                image_shapes = torch.tensor(
                    image_shapes, device=device, dtype=torch.long
                )

            # Prepare 2D position embedding
            orig_size_sq = pos_embed_weight.shape[0]
            orig_size = int(orig_size_sq**0.5)
            pos_embed_2d = (
                pos_embed_weight.view(orig_size, orig_size, hidden_size)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(device=device, dtype=torch.float32)
            )

            # Calculate target dimensions for each patch
            target_h = torch.cat(
                [image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]
            ).to(device=device, dtype=torch.float32)
            target_w = torch.cat(
                [image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]
            ).to(device=device, dtype=torch.float32)

            # Normalize coordinates to [-1, 1] range for grid_sample
            h_coords = h_coords.to(device=device, dtype=torch.float32)
            w_coords = w_coords.to(device=device, dtype=torch.float32)
            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

            # Create sampling grid
            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)

            # Perform bicubic interpolation
            interpolated_embed_fp32 = F.grid_sample(
                pos_embed_2d,
                grid,
                mode="bicubic",
                align_corners=False,
                padding_mode="border",
            )

            # Reshape and convert back to original dtype
            adapted_pos_embed_fp32 = (
                interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
            )
            adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(
                embeddings.device
            )

        # Add adapted position encoding to embeddings
        embeddings = embeddings + adapted_pos_embed
        return embeddings


class Glm4vVisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        if seqlen > self._seq_len_cached:
            seqlen *= 2
            self._seq_len_cached = seqlen
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0,
                        self.dim,
                        2,
                        dtype=torch.float,
                        device=self.inv_freq.device,
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Glm4vVisionModel(nn.Module):
    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.out_hidden_size = vision_config.out_hidden_size

        self.patch_embed = Glm4vVisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Glm4vVisionBlock(
                    config=vision_config,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                )
                for layer_idx in range(depth)
            ]
        )

        self.merger = Glm4vPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            bias=False,
            prefix=add_prefix("merger", prefix),
        )

        self.embeddings = Glm4vVisionEmbeddings(vision_config)

        self.post_conv_layernorm = Glm4vRMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )
        self.downsample = nn.Conv2d(
            in_channels=vision_config.hidden_size,
            out_channels=vision_config.out_hidden_size,
            kernel_size=vision_config.spatial_merge_size,
            stride=vision_config.spatial_merge_size,
        )
        self.post_layernorm = Glm4vRMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = self.post_conv_layernorm(x)

        # compute position embedding
        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        x = self.embeddings(
            x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
        )

        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        rotary_pos_emb_tuple = (emb.cos(), emb.sin())

        # x.shape: (s, b, d) where b=1 for vision processing
        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)

        # adapter
        x = self.post_layernorm(x)
        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x).view(-1, self.out_hidden_size)
        x = self.merger(x)

        return x


class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    def __init__(
        self,
        config: Glm4vConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)

        self.config = config

        self.model = Glm4Model(
            config,
            quant_config,
            prefix=add_prefix("model", prefix),
        )
        self.visual = Glm4vVisionModel(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        pixel_values = torch.cat(
            [item.feature.squeeze(0) for item in items], dim=0
        ).type(self.visual.dtype)
        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
        # For multi-image, pixel_values is [num_of_images, L, C] shape
        # assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        split_sizes = (
            image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
        ).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return torch.cat(image_embeds)

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        pixel_values_videos = torch.cat(
            [item.feature.squeeze(0) for item in items], dim=0
        ).type(self.visual.dtype)
        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
        # For multi-video, pixel_values_videos is [num_of_videos, L, C] shape
        # assert pixel_values_videos.dim() == 2, pixel_values_videos.dim()
        assert video_grid_thw.dim() == 2, video_grid_thw.dim()

        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
        temp_frames_hw = []
        for t, h, w in video_grid_thw:
            repeated_row = (
                torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
            )
            temp_frames_hw.append(repeated_row)
        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=flattened_video_grid_thw
        )
        split_sizes = (
            video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
        ).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return torch.cat(video_embeds)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".up_proj", 1),
            (".gate_up_proj", ".gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "language_model." in name:
                name = name.replace("language_model.", "")
            if "model.visual." in name:
                name = name.replace("model.visual.", "visual.")

            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    print(params_dict.keys())
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = [Glm4vForConditionalGeneration]
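The new glm4v.py module follows sglang's convention of exporting EntryClass, which the model registry (sglang/srt/models/registry.py) scans to map architecture names to implementations. The sketch below is only an illustrative approximation of that discovery pattern, not the registry's actual code; collect_entry_classes is a hypothetical helper.

# Hypothetical sketch of the EntryClass discovery convention (not sglang's
# actual registry): walk a models package, collect each module's EntryClass
# export, and index the exported classes by architecture name.
import importlib
import pkgutil
from typing import Dict, Type


def collect_entry_classes(package_name: str) -> Dict[str, Type]:
    registry: Dict[str, Type] = {}
    package = importlib.import_module(package_name)
    for mod_info in pkgutil.iter_modules(package.__path__):
        module = importlib.import_module(f"{package_name}.{mod_info.name}")
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            continue
        entries = entry if isinstance(entry, list) else [entry]
        for cls in entries:
            registry[cls.__name__] = cls  # e.g. "Glm4vForConditionalGeneration"
    return registry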