sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/entrypoints/engine.py +44 -22
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +25 -15
- sglang/srt/managers/scheduler.py +263 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tp_worker.py +51 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +115 -57
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +34 -22
- sglang/srt/openai_api/protocol.py +11 -1
- sglang/srt/server_args.py +67 -22
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +88 -9
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,639 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
# ruff: noqa: E501
|
3
|
+
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
|
4
|
+
# This file is meant to be used in kimi_vl.py only
|
5
|
+
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
|
6
|
+
#
|
7
|
+
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
|
8
|
+
#
|
9
|
+
# Licensing Information:
|
10
|
+
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
|
11
|
+
# - Other parts of the code are licensed under the MIT License.
|
12
|
+
#
|
13
|
+
# Apache License, Version 2.0:
|
14
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
15
|
+
# you may not use this file except in compliance with the License.
|
16
|
+
# You may obtain a copy of the License at
|
17
|
+
#
|
18
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
19
|
+
#
|
20
|
+
# Unless required by applicable law or agreed to in writing, software
|
21
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
22
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
23
|
+
# See the License for the specific language governing permissions and
|
24
|
+
# limitations under the License.
|
25
|
+
#
|
26
|
+
# MIT License:
|
27
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
28
|
+
# of this software and associated documentation files (the "Software"), to deal
|
29
|
+
# in the Software without restriction, including without limitation the rights
|
30
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
31
|
+
# copies of the Software, and to permit persons to whom the Software is
|
32
|
+
# furnished to do so, subject to the following conditions:
|
33
|
+
#
|
34
|
+
# The above copyright notice and this permission notice shall be included in all
|
35
|
+
# copies or substantial portions of the Software.
|
36
|
+
#
|
37
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
38
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
39
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
40
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
41
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
42
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
43
|
+
# SOFTWARE.
|
44
|
+
import math
|
45
|
+
from copy import deepcopy
|
46
|
+
from functools import cached_property
|
47
|
+
from typing import List, Optional, Sequence, Tuple, Union
|
48
|
+
|
49
|
+
import torch
|
50
|
+
import torch.nn as nn
|
51
|
+
import torch.nn.functional as F
|
52
|
+
from transformers.activations import ACT2FN, PytorchGELUTanh
|
53
|
+
from transformers.modeling_utils import PreTrainedModel
|
54
|
+
|
55
|
+
try:
|
56
|
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
57
|
+
except ImportError:
|
58
|
+
flash_attn_varlen_func = None
|
59
|
+
|
60
|
+
from sglang.srt.configs import MoonViTConfig
|
61
|
+
|
62
|
+
|
63
|
+
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
):
    """Packed (variable-length) multi-head attention via FlashAttention-2.

    Several sequences are concatenated along dim 0 and ``*_cu_seqlens`` mark
    the segment boundaries; attention is non-causal within each segment.

    Args:
        q, k, v: tensors of shape (tot_seqlens, num_heads, head_dim); q must be
            bfloat16 or float16.
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. The first
            element should be 0 and the last element must equal q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k. The first
            element should be 0 and the last element must equal k.shape[0]
            (and v.shape[0]).

    Returns:
        The flash-attn output with its last two dims (num_heads, head_dim)
        flattened into one, i.e. (tot_seqlens_q, num_heads * head_dim).

    Raises:
        ImportError: if flash_attn could not be imported.
    """
    if flash_attn_varlen_func is None:
        raise ImportError(
            "flash_attn is not installed, this function needs flash_attn_varlen_func from flash_attn"
        )

    # Layout checks for the packed format.
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
    assert (
        k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
    ), "k_cu_seqlens must sum to k.shape[0]"
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"

    def _longest_segment(cu_seqlens: torch.Tensor) -> int:
        # Largest gap between consecutive boundaries; .item() syncs to host.
        return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

    packed_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        _longest_segment(q_cu_seqlens),
        _longest_segment(k_cu_seqlens),
        causal=False,
    )
    # Merge the trailing (num_heads, head_dim) dims into a single hidden dim.
    return packed_out.flatten(start_dim=-2)
|
113
|
+
|
114
|
+
|
115
|
+
def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Packed multi-head attention using torch scaled dot product attention.

    Tokens of several independent sequences are concatenated along dim 0. A
    block-diagonal boolean mask restricts the i-th query segment to attend only
    to the i-th key segment (non-causal).

    Args:
        q: tensor of shape (tot_seqlens_q, num_heads, head_dim).
        k, v: tensors of shape (tot_seqlens_k, num_heads, head_dim).
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
            The first element should be 0 and the last element must be q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k, with the
            same number of segments as ``q_cu_seqlens``. The first element should
            be 0 and the last element must be k.shape[0]. Defaults to
            ``q_cu_seqlens`` (packed self-attention).

    Returns:
        Tensor of shape (tot_seqlens_q, num_heads * head_dim).
    """
    # Unified format legal check
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
    if k_cu_seqlens is None:
        k_cu_seqlens = q_cu_seqlens
    assert (
        k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
    ), "k_cu_seqlens must sum to k.shape[0]"
    assert len(k_cu_seqlens) == len(
        q_cu_seqlens
    ), "q_cu_seqlens and k_cu_seqlens must describe the same number of sequences"

    seq_length = q.shape[0]
    kv_length = k.shape[0]
    # Boundaries are read once on the host instead of slicing with 0-d tensors.
    q_bounds = torch.as_tensor(q_cu_seqlens).tolist()
    k_bounds = torch.as_tensor(k_cu_seqlens).tolist()

    # Rows index queries, columns index keys. The mask was previously built
    # from q_cu_seqlens alone, which is wrong whenever k's packing differs.
    attention_mask = torch.zeros(
        [1, seq_length, kv_length], device=q.device, dtype=torch.bool
    )
    for i in range(1, len(q_bounds)):
        attention_mask[
            ...,
            q_bounds[i - 1] : q_bounds[i],
            k_bounds[i - 1] : k_bounds[i],
        ] = True

    # SDPA expects (num_heads, seq, head_dim); the mask broadcasts over heads.
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output
|
155
|
+
|
156
|
+
|
157
|
+
# Packed-attention backends, looked up by MoonVitEncoderLayer's
# ``attn_implementation`` string.
VL_VISION_ATTENTION_FUNCTIONS = dict(
    flash_attention_2=multihead_attention,
    sdpa=sdpa_attention,
)
|
161
|
+
|
162
|
+
|
163
|
+
def _apply_rope_input_validation(x, freqs_cis):
    """Assert that ``x`` and ``freqs_cis`` are shape/dtype compatible for rope.

    ``x`` must be (..., num_heads, head_dim) and ``freqs_cis`` must be
    (..., head_dim // 2) with the same leading dims and dtype complex64.
    """
    assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
    assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype


def apply_rope(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by precomputed complex frequencies.

    Consecutive (even, odd) feature pairs are treated as complex numbers and
    multiplied by ``freqs_cis``, which is broadcast over the head dimension.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads_q, head_dim)
        xk: key, tensor of shape (..., num_heads_k, head_dim); num_heads_k may
            differ from num_heads_q.
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It
            contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors with the same shapes and dtypes as xq and xk.
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2

    # Each tensor is reshaped by ITS OWN shape. The previous code reshaped xk
    # with xq's shape, which mangled xk when the head counts differed.
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
|
191
|
+
|
192
|
+
|
193
|
+
class Learnable2DInterpPosEmb(nn.Module):
    """Learnable 2D absolute position embedding, resized per image on demand.

    The parameter table has shape (height, width, dim). For each (h, w) grid in
    ``grid_hws`` the table is used directly when (h, w) equals its native size;
    otherwise it is resized with ``F.interpolate`` using ``interpolation_mode``.
    Per-image embeddings are flattened row-major, concatenated in grid order and
    added to ``x``.
    """

    def __init__(
        self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
    ) -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: packed features of shape (sum(h * w), dim), images back to back.
            grid_hws: (N, 2) tensor of per-image (height, width).

        Returns:
            ``x`` plus the concatenated per-image position embeddings.
        """
        native_hw = tuple(self.weight.shape[:-1])
        pos_embs = []
        for shape in grid_hws.tolist():
            # tolist() yields lists; comparing a list with a torch.Size is always
            # False, so compare as tuples to actually hit the no-resize path.
            if tuple(shape) == native_hw:
                pos_embs.append(self.weight.flatten(end_dim=1))
            else:
                resized = F.interpolate(
                    self.weight.permute((2, 0, 1)).unsqueeze(0),  # (1, dim, H, W)
                    size=shape,
                    mode=self.interpolation_mode,
                )
                pos_embs.append(
                    resized.squeeze(0).permute((1, 2, 0)).flatten(end_dim=1)
                )
        out = x + torch.cat(pos_embs)
        return out
|
226
|
+
|
227
|
+
|
228
|
+
class MoonVisionPatchEmbed(nn.Module):
    """Embed pixel patches with a strided Conv2d and add 2D position embeddings."""

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: Union[int, Tuple[int, int]] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
    ):
        super().__init__()
        assert isinstance(
            patch_size, (int, Sequence)
        ), f"Invalid patch_size type: {type(patch_size)}"
        # A scalar patch size means a square kernel.
        kernel = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        assert (
            len(kernel) == 2
        ), f"Expected patch_size to be a tuple of 2, got {kernel}"
        self.patch_size = kernel

        # kernel == stride: one output cell per non-overlapping patch.
        self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=kernel, stride=kernel)
        self.pos_emb = Learnable2DInterpPosEmb(
            height=pos_emb_height, width=pos_emb_width, dim=out_dim
        )

    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: pixel input for the Conv2d ``proj``, one entry per patch along
                dim 0 — presumably (L, in_dim, patch_h, patch_w); TODO confirm
                with the caller.
            grid_hw (N, 2): per-image grid height and width.

        Returns:
            (L, out_dim) patch features with position embeddings added.
        """
        num_patches = x.size(0)
        patch_features = self.proj(x).view(num_patches, -1)
        return self.pos_emb(patch_features, grid_hw)
|
270
|
+
|
271
|
+
|
272
|
+
class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    Intended usage:
    1. Create one instance up front. It lazily computes (and caches) cis(freqs)
       for every cell of a ``max_height`` x ``max_width`` grid.
    2. Before each forward pass, call ``get_freqs_cis_by_seqlens`` or
       ``get_freqs_cis_by_idx`` to gather the ``freqs_cis`` tensor for this batch.
    3. During the forward pass, pass ``freqs_cis`` to each attention layer and
       rotate q/k with it just before each attention operation.
    The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the per-head attention dimension; must be divisible
            by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device on which the cis table and the output of
            ``get_freqs_cis_by_idx`` are created
    """

    def __init__(
        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
    ):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.device = device

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    @cached_property
    def precomputed_freqs_cis(self) -> torch.Tensor:
        """Calculate cis(freqs) for each position in the max-size 2D grid.

        Returns:
            complex64 tensor of shape (max_height, max_width, dim // 2) where,
            for i in [0, dim // 4):
                ret[h, w, 2*i]     = cis(w * theta_base**(-4*i/dim))  # width (x) axis
                ret[h, w, 2*i + 1] = cis(h * theta_base**(-4*i/dim))  # height (y) axis
            with ``cis x = cos x + i sin x``.
        """
        N = self.max_height * self.max_width
        # Created directly on the target device (no host -> device copy).
        flat_pos = torch.arange(N, dtype=torch.float32, device=self.device)
        x_pos = flat_pos % self.max_width  # column index w
        y_pos = flat_pos // self.max_width  # row index h
        dim_range = torch.arange(
            0, self.dim, 4, dtype=torch.float32, device=self.device
        )[: (self.dim // 4)]  # C/4 values: 0, 4, 8, ...
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2 -- x (width) in slot 0, y (height) in slot 1
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
        )
        # max_height, max_width, C/2 (even index = width, odd index = height)
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
        """Gather row-major cis tables for a batch of packed 2D grids.

        Args:
            grid_hws (torch.Tensor): (N, 2) tensor of (height, width) rows, each
                within [1, max_height] x [1, max_width]. Rows with a temporal
                dimension are not supported.
        Returns:
            freqs_cis: tensor of shape (sum(height * width), dim//2), grids
                concatenated in order.
        """
        shapes = grid_hws.tolist()
        assert all(
            1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
        ), (
            shapes,
            self.max_height,
            self.max_width,
        )
        freqs_cis = torch.cat(
            [
                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
                for h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis

    def get_freqs_cis_by_idx(
        self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
    ) -> torch.Tensor:
        """Gather cis values for explicit (h, w) positions.

        Args:
            pos_idx: tensor of shape (..., 2), containing the (h, w) position
                indices of each 2D token.
            pos_idx_mask: bool mask of shape (...); the leading dimensions must
                match pos_idx. Rope is only applied to tokens with a True mask;
                ``freqs_cis`` for tokens with a False mask will be ones.
        Returns:
            freqs_cis: complex64 tensor of shape (..., dim//2) on ``self.device``.
        """
        assert (
            pos_idx.shape[:-1] == pos_idx_mask.shape
            and pos_idx.shape[-1] == 2
            and pos_idx.ndim == pos_idx_mask.ndim + 1
        ), (pos_idx.shape, pos_idx_mask.shape)
        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype

        shp = pos_idx_mask.shape + (self.dim // 2,)  # ..., head_dim/2
        # Ones = identity rotation for masked-out tokens.
        freqs_cis = torch.ones(
            shp, dtype=torch.complex64, device=self.device
        )  # ..., head_dim/2
        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
            pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
        ]
        return freqs_cis
|
387
|
+
|
388
|
+
|
389
|
+
class MLP2(nn.Module):
    """Two-layer feed-forward block computing ``fc1(activation(fc0(x)))``.

    Args:
        dims: [in_dim, hidden_dim, out_dim]
        activation: callable applied between the two linear layers.
        bias: whether to use bias in both linear layers.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        in_dim, hidden_dim, out_dim = dims
        self.fc0 = nn.Linear(in_dim, hidden_dim, bias=bias)
        self.fc1 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.activation = activation
        for layer in (self.fc0, self.fc1):
            self._init_linear(layer)

    @staticmethod
    def _init_linear(layer: nn.Linear) -> None:
        # Truncated normal weights with std = sqrt(2 / fan_in); zero bias.
        nn.init.trunc_normal_(layer.weight, std=math.sqrt(2 / layer.in_features))
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(self.activation(self.fc0(x)))
|
411
|
+
|
412
|
+
|
413
|
+
class MoonVitEncoderLayer(nn.Module):
    """Pre-norm transformer block over packed (variable-length) token sequences.

    ``forward`` computes::

        h = x + wo(attention(rope(split_heads(wqkv(norm0(x))))))
        out = h + mlp(norm1(h))

    The attention kernel is looked up by name in
    ``VL_VISION_ATTENTION_FUNCTIONS`` using ``attn_implementation``.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = "flash_attention_2",  # use fa2 in sglang by default
        activation=F.gelu,
        attn_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Optional[torch.Tensor] = None,
    ):
        """Self-attention over packed sequences with a fused QKV projection.

        Args:
            x (torch.Tensor): packed hidden states, (tot_seqlens, hidden_dim).
                The attention kernels assert 3-D q/k/v, so x must be 2-D.
            cu_seqlens (torch.Tensor): cumulative sequence lengths delimiting
                the packed sequences (first 0, last tot_seqlens).
            rope_freqs_cis (torch.Tensor, optional): complex64 rotary table of
                shape (tot_seqlens, head_dim // 2). When None, rope is skipped.

        Returns:
            (tot_seqlens, hidden_dim) attention output after ``wo``.
        """
        xqkv = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (tot_seqlens, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        # The parameter is Optional; previously None crashed inside apply_rope.
        if rope_freqs_cis is not None:
            xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(
            xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
        )

        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: packed tokens of shape (tot_seqlens, hidden_dim).
            cu_seqlens: cumulative sequence lengths of the packed sequences.
            rope_freqs_cis: optional rotary table, see ``attention_qkvpacked``.

        Returns:
            output: same shape as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)
        attn_out = self.attention_qkvpacked(
            hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
        )
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.mlp(self.norm1(hidden_states))
        hidden_states = residual + hidden_states
        return hidden_states
|
493
|
+
|
494
|
+
|
495
|
+
class MoonVitEncoder(nn.Module):
    """Stack of MoonVitEncoderLayer blocks sharing one 2D rope, plus a final LayerNorm."""

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
    ) -> None:
        super().__init__()

        # Rope table for grids up to 512 x 512 patches;
        # get_freqs_cis_by_seqlens asserts when a grid exceeds this.
        self.rope_2d = Rope2DPosEmb(
            block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
        )
        self.blocks = nn.ModuleList(
            [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]
        )
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: packed patch tokens of shape (sum(h * w), hidden_dim).
            grid_hw: (N, 2) per-image (height, width) in patches.

        Returns:
            Encoded tokens with the same shape as ``hidden_states``.
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)

        # cu_seqlens = [0, h0*w0, h0*w0 + h1*w1, ...]. Both cat inputs are built
        # on grid_hw's device (mixing devices in torch.cat raises); the result is
        # then moved next to hidden_states for the attention kernel.
        lengths = torch.cat(
            (
                torch.zeros(1, device=grid_hw.device, dtype=grid_hw.dtype),
                grid_hw[:, 0] * grid_hw[:, 1],
            )
        )
        cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32).to(hidden_states.device)

        for block in self.blocks:
            hidden_states = block(
                hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
            )

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states
|
534
|
+
|
535
|
+
|
536
|
+
def patch_merger(
    x: torch.Tensor,
    grid_hw: torch.Tensor,
    merge_kernel_size: list[int, int] = (2, 2),
) -> List[torch.Tensor]:
    """Group each image's patch tokens into non-overlapping merge windows.

    ``x`` holds the tokens of all images packed back to back, row-major within
    each image. For an image of ``height x width`` patches and a kernel
    ``(kh, kw)``, the output tensor has shape
    ``(height // kh * width // kw, kh * kw, d_model)``. Each row gathers the
    ``kh * kw`` tokens of one window, and windows are ordered row-major.
    ``height``/``width`` must be divisible by ``kh``/``kw`` or ``view`` raises.

    Args:
        x: packed tokens of shape (sum(height * width), d_model).
        grid_hw: (N, 2) per-image (height, width) in patches.
        merge_kernel_size: (kh, kw) window size.

    Returns:
        One tensor per image, in the order of ``grid_hw``.
    """
    hidden = x.size(-1)
    merged: List[torch.Tensor] = []
    offset = 0
    for hw in grid_hw.tolist():
        rows, cols = hw[0], hw[1]
        n_tokens = rows * cols
        tokens = x[offset : offset + n_tokens]
        offset += n_tokens

        kh, kw = merge_kernel_size
        out_rows, out_cols = rows // kh, cols // kw
        # (rows*cols, d) -> (out_rows, kh, out_cols, kw, d) -> (out_rows, out_cols, kh, kw, d)
        windows = (
            tokens.view(out_rows, kh, out_cols, kw, hidden)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
        )
        merged.append(windows.view(out_rows * out_cols, kh * kw, -1))
    return merged
|
563
|
+
|
564
|
+
|
565
|
+
class MoonVitVLProjector(nn.Module):
    """Project merged vision windows to ``out_dim`` features.

    Applies LayerNorm over ``in_channels``, flattens each merge window into
    ``in_channels * kh * kw`` features, then Linear -> activation -> Linear.
    """

    def __init__(
        self,
        in_channels: int,
        merge_kernel_size: list[int, int],
        hidden_act: str = "gelu",
        ln_eps: float = 1e-5,
        out_dim: int = 4096,
    ):
        super().__init__()
        self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]

        # torch.nn has no ``nn`` attribute, so ``nn.nn.LayerNorm`` made this
        # constructor raise AttributeError.
        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = ACT2FN[hidden_act]
        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: features whose last dim is ``in_channels`` and whose
                element count is a multiple of ``self.hidden_size`` — presumably
                (num_windows, kh * kw, in_channels) as produced by
                ``patch_merger``; TODO confirm with caller.

        Returns:
            (num_windows, out_dim) projected features.
        """
        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
|
589
|
+
|
590
|
+
|
591
|
+
class MoonVitPretrainedModel(PreTrainedModel):
    """MoonViT vision tower: patch embedding -> packed encoder -> patch merging.

    Images of different sizes are processed together as one packed token
    sequence whose per-image extents are given by ``grid_hw``.
    """

    config_class = MoonViTConfig
    model_type = "moonvit"
    # NOTE(review): no class named "PackingTransformer" exists in this module
    # (the encoder block is MoonVitEncoderLayer); confirm this entry is intended.
    _no_split_modules = ["PackingTransformer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.patch_embed = MoonVisionPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
        )

        self.encoder = MoonVitEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": PytorchGELUTanh(),
                "attn_bias": True,
                "attn_implementation": config._attn_implementation,
            },
        )

    def forward(
        self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
    ) -> List[torch.Tensor]:
        """
        Args:
            pixel_values (torch.Tensor): pixel patches fed to ``patch_embed``.
            grid_hw (torch.Tensor): (N, 2) per-image grid height and width in patches.

        Returns:
            List[torch.Tensor]: one tensor per image (the output of
            ``patch_merger``), shaped (h // kh * w // kw, kh * kw, hidden_size)
            for merge kernel (kh, kw).
        """
        hidden_states = self.patch_embed(pixel_values, grid_hw)
        hidden_states = self.encoder(hidden_states, grid_hw)
        hidden_states = patch_merger(
            hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
        )
        return hidden_states
|