sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/kimi_vl_moonvit.py
@@ -0,0 +1,639 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # ruff: noqa: E501
+ # Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
+ # This file is meant to be used in kimi_vl.py only
+ # Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
+ #
+ # The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
+ #
+ # Licensing Information:
+ # - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
+ # - Other parts of the code are licensed under the MIT License.
+ #
+ # Apache License, Version 2.0:
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # MIT License:
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ import math
+ from copy import deepcopy
+ from functools import cached_property
+ from typing import List, Optional, Sequence, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.activations import ACT2FN, PytorchGELUTanh
+ from transformers.modeling_utils import PreTrainedModel
+
+ try:
+     from flash_attn.flash_attn_interface import flash_attn_varlen_func
+ except ImportError:
+     flash_attn_varlen_func = None
+
+ from sglang.srt.configs import MoonViTConfig
+
+
+ def multihead_attention(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     q_cu_seqlens: Optional[torch.Tensor] = None,
+     k_cu_seqlens: Optional[torch.Tensor] = None,
+ ):
+     """Multi-head attention using flash attention 2.
+     This function is used to handle the case where the query, key, and value are packed.
+     Args:
+         q, k, v: tensor of shape (tot_seqlens, num_heads, head_dim).
+         q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
+             The first element should be 0 and the last element should be q.shape[0].
+         k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
+             The first element should be 0 and the last element should be k.shape[0].
+
+     Returns:
+         output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
+             where dim = num_heads * head_dim
+     """
+     if flash_attn_varlen_func is None:
+         raise ImportError(
+             "flash_attn is not installed, this function needs flash_attn_varlen_func from flash_attn"
+         )
+     # Unified format legal check
+     assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
+     assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
+     assert (
+         k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
+     ), "k_cu_seqlens must sum to k.shape[0]"
+     assert q.dtype in [
+         torch.bfloat16,
+         torch.float16,
+     ], f"unsupported dtype {q.dtype} for multihead attn"
+
+     max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
+     max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
+     attn_out = flash_attn_varlen_func(
+         q,
+         k,
+         v,
+         q_cu_seqlens,
+         k_cu_seqlens,
+         max_seqlen_q,
+         max_seqlen_k,
+         causal=False,
+     )
+     attn_out = attn_out.flatten(start_dim=-2)
+
+     return attn_out
+
+
+ def sdpa_attention(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     q_cu_seqlens: Optional[torch.Tensor] = None,
+     k_cu_seqlens: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     """Multi-head attention using torch scaled dot product attention.
+     This function is used to handle the case where the query, key, and value are packed.
+     Args:
+         q, k, v: tensor of shape (tot_seqlens, num_heads, head_dim).
+         q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
+             The first element should be 0 and the last element should be q.shape[0].
+         k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
+             The first element should be 0 and the last element should be k.shape[0].
+
+     Returns:
+         output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
+             where dim = num_heads * head_dim
+     """
+     # Unified format legal check
+     assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
+     assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
+     seq_length = q.shape[0]
+     attention_mask = torch.zeros(
+         [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+     )
+     for i in range(1, len(q_cu_seqlens)):
+         attention_mask[
+             ...,
+             q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+             q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+         ] = True
+     q = q.transpose(0, 1)
+     k = k.transpose(0, 1)
+     v = v.transpose(0, 1)
+     attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+     attn_output = attn_output.transpose(0, 1)
+     attn_output = attn_output.reshape(seq_length, -1)
+     return attn_output
+
+
+ VL_VISION_ATTENTION_FUNCTIONS = {
+     "flash_attention_2": multihead_attention,
+     "sdpa": sdpa_attention,
+ }
+
+
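
The two helpers above share a packed interface: patch tokens from all images are concatenated along the first axis, and cu_seqlens records the cumulative boundaries so each image only attends to its own tokens (sdpa_attention builds a block-diagonal boolean mask from it). A minimal sketch of that interface, using the SDPA path so it runs without flash_attn; the sizes are made up and the module path is assumed from the file list above:

import torch

from sglang.srt.models.kimi_vl_moonvit import sdpa_attention

num_heads, head_dim = 2, 8
# two images packed together: 3 + 5 patch tokens
q = torch.randn(8, num_heads, head_dim)
k = torch.randn(8, num_heads, head_dim)
v = torch.randn(8, num_heads, head_dim)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # cumulative lengths
out = sdpa_attention(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens)
print(out.shape)  # torch.Size([8, 16]) == (tot_seqlens, num_heads * head_dim)
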
+ def _apply_rope_input_validation(x, freqs_cis):
+     assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
+     assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
+     assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
+     assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
+
+
+ def apply_rope(
+     xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Args: (The leading dimensions of all inputs should be the same)
+         xq: query, tensor of shape (..., num_heads, head_dim)
+         xk: key, tensor of shape (..., num_heads, head_dim)
+         freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
+     Returns:
+         xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
+     """
+     _apply_rope_input_validation(xq, freqs_cis)
+     _apply_rope_input_validation(xk, freqs_cis)
+
+     freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
+     # ..., num_heads, head_dim/2
+     xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2))
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
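
apply_rope views each consecutive channel pair of q and k as one complex number and multiplies it by a unit-magnitude cis value, so it rotates pairs without changing per-token norms. A small sketch of that property with hypothetical shapes (module path assumed as above):

import torch

from sglang.srt.models.kimi_vl_moonvit import apply_rope

tokens, heads, head_dim = 6, 2, 8
xq = torch.randn(tokens, heads, head_dim)
xk = torch.randn(tokens, heads, head_dim)
# one unit-magnitude rotation per (token, channel pair)
freqs_cis = torch.polar(torch.ones(tokens, head_dim // 2), torch.rand(tokens, head_dim // 2))
q_out, k_out = apply_rope(xq, xk, freqs_cis)
print(q_out.shape)  # torch.Size([6, 2, 8])
print(torch.allclose(q_out.norm(dim=-1), xq.norm(dim=-1), atol=1e-5))  # True: norms preserved
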
+ class Learnable2DInterpPosEmb(nn.Module):
+
+     def __init__(
+         self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
+     ) -> None:
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.interpolation_mode = interpolation_mode
+         self.weight = nn.Parameter(torch.empty(height, width, dim))
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.normal_(self.weight)
+
+     def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
+         pos_embs = []
+         for shape in grid_hws.tolist():
+             if shape == self.weight.shape[:-1]:
+                 pos_embs.append(self.weight.flatten(end_dim=1))
+             else:
+                 pos_embs.append(
+                     F.interpolate(
+                         self.weight.permute((2, 0, 1)).unsqueeze(0),
+                         size=shape,
+                         mode=self.interpolation_mode,
+                     )
+                     .squeeze(0)
+                     .permute((1, 2, 0))
+                     .flatten(end_dim=1)
+                 )
+         out = x + torch.cat(pos_embs)
+         return out
+
+
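
Learnable2DInterpPosEmb keeps a single learned (height, width, dim) grid and bicubically resizes it to each image's patch grid before adding it to the packed features. A hedged sketch with made-up sizes:

import torch

from sglang.srt.models.kimi_vl_moonvit import Learnable2DInterpPosEmb

pos_emb = Learnable2DInterpPosEmb(height=14, width=14, dim=32)
grid_hws = torch.tensor([[7, 7], [14, 14]])  # two images with different patch grids
x = torch.randn(7 * 7 + 14 * 14, 32)         # packed patch features
print(pos_emb(x, grid_hws).shape)            # torch.Size([245, 32])
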
+ class MoonVisionPatchEmbed(nn.Module):
+
+     def __init__(
+         self,
+         out_dim: int,
+         in_dim: int = 3,
+         patch_size: Union[int, Tuple[int, int]] = (14, 14),
+         pos_emb_height: int = 14,
+         pos_emb_width: int = 14,
+     ):
+         super().__init__()
+         assert isinstance(
+             patch_size, (int, Sequence)
+         ), f"Invalid patch_size type: {type(patch_size)}"
+         if isinstance(patch_size, int):
+             patch_size = (patch_size, patch_size)
+         assert (
+             len(patch_size) == 2
+         ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
+         self.patch_size = patch_size
+
+         self.proj = nn.Conv2d(
+             in_dim, out_dim, kernel_size=patch_size, stride=patch_size
+         )
+
+         self.pos_emb = Learnable2DInterpPosEmb(
+             height=pos_emb_height, width=pos_emb_width, dim=out_dim
+         )
+
+     def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x (L, Channels): input tensor
+             grid_hw (N, 2): grid height and width
+
+         Returns:
+             (L, Cout) tensor
+         """
+         x = self.proj(x).view(x.size(0), -1)
+         # apply positional embedding
+         x = self.pos_emb(x, grid_hw)
+         return x
+
+
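
The forward docstring above lists x as (L, Channels), but since self.proj is an nn.Conv2d whose kernel equals the patch size, the patches presumably arrive pre-extracted as (L, C, patch_h, patch_w) and each patch is reduced to a single out_dim vector. A sketch under that assumption, with a hypothetical small out_dim:

import torch

from sglang.srt.models.kimi_vl_moonvit import MoonVisionPatchEmbed

embed = MoonVisionPatchEmbed(out_dim=32, patch_size=14)
grid_hw = torch.tensor([[4, 4]])           # one image, a 4x4 patch grid
pixel_values = torch.randn(16, 3, 14, 14)  # assumed (L, C, patch_h, patch_w) layout
print(embed(pixel_values, grid_hw).shape)  # torch.Size([16, 32])
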
+ class Rope2DPosEmb(nn.Module):
+     """2D rotary position embedding with multi-resolution support.
+
+     This class is intended to be used in the following way:
+     1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
+     2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
+     3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
+     The rope is shared across all attention layers and all heads.
+
+     Refs:
+     - RoFormer: https://arxiv.org/abs/2104.09864
+     - VisionLLaMA: https://arxiv.org/abs/2403.00522
+     - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
+
+     Args:
+         dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
+         max_height (int): the maximum height of the 2D grid
+         max_width (int): the maximum width of the 2D grid
+         theta_base (float): the base of the theta
+         device (str): the device to store the precomputed cis
+     """
+
+     def __init__(
+         self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+     ):
+         super().__init__()
+         self.dim = dim
+         assert self.dim % 4 == 0, "dim must be divisible by 4"
+         self.max_height = max_height
+         self.max_width = max_width
+         self.theta_base = theta_base
+         self.device = device
+
+     def extra_repr(self):
+         return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
+
+     @cached_property
+     def precomputed_freqs_cis(self) -> torch.Tensor:
+         """Calculate the cis(freqs) for each position in the 2D grid.
+
+         Return: complex tensor of shape (max_height, max_width, dim//2) and value:
+             height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
+             width axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
+             note: `cis` is a mathematical notation defined by cis x = cos x + i sin x.
+         """
+         N = self.max_height * self.max_width
+         flat_pos = torch.arange(0, N).float().to(self.device)
+         x_pos = flat_pos % self.max_width
+         y_pos = flat_pos // self.max_width
+         dim_range = (
+             torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
+         )  # C/4
+         freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
+         x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
+         y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
+         x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
+         y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
+         # N, C/4, 2
+         freqs_cis = torch.cat(
+             [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
+         )
+         # max_height, max_width, C/2
+         freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
+         return freqs_cis
+
+     def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
+         Returns:
+             freqs_cis: tensor of shape (sum(t * height * width), dim//2)
+         """
+         shapes = grid_hws.tolist()
+         assert all(
+             1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
+         ), (
+             shapes,
+             self.max_height,
+             self.max_width,
+         )
+         freqs_cis = torch.cat(
+             [
+                 self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
+                 for h, w in shapes
+             ],
+             dim=0,
+         )
+         return freqs_cis
+
+     def get_freqs_cis_by_idx(
+         self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Args:
+             pos_idx: tensor of shape (..., 2), it contains the (h, w) position indices of each 2D token.
+             pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
+                 Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask will be ones.
+         Return:
+             freqs_cis: tensor of shape (..., dim//2)
+         """
+         assert (
+             pos_idx.shape[:-1] == pos_idx_mask.shape
+             and pos_idx.shape[-1] == 2
+             and pos_idx.ndim == pos_idx_mask.ndim + 1
+         ), (pos_idx.shape, pos_idx_mask.shape)
+         assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
+
+         shp = pos_idx_mask.shape + (self.dim // 2,)  # ..., head_dim/2
+         freqs_cis = torch.ones(
+             shp, dtype=torch.complex64, device=self.device
+         )  # ..., head_dim/2
+         freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
+             pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
+         ]
+         return freqs_cis
+
+
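
Rope2DPosEmb precomputes one unit-magnitude complex rotation per (row, column, channel pair), and get_freqs_cis_by_seqlens slices and flattens that table per image so it lines up with the packed token order. A small CPU sketch with hypothetical sizes (device="cpu" overrides the CUDA default):

import torch

from sglang.srt.models.kimi_vl_moonvit import Rope2DPosEmb

rope = Rope2DPosEmb(dim=16, max_height=8, max_width=8, device="cpu")
table = rope.precomputed_freqs_cis
print(table.shape, table.dtype)  # torch.Size([8, 8, 8]) torch.complex64
# every entry has magnitude 1; at base frequency 1.0 position (3, 3) stores cis(3.0)
print(torch.allclose(table.abs(), torch.ones_like(table.real)))  # True
print(torch.allclose(table[3, 3, 0], torch.polar(torch.tensor(1.0), torch.tensor(3.0))))  # True
freqs_cis = rope.get_freqs_cis_by_seqlens(torch.tensor([[2, 3], [4, 4]]))
print(freqs_cis.shape)  # torch.Size([22, 8]) == (2*3 + 4*4, dim // 2)
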
+ class MLP2(nn.Module):
+     """
+     Args:
+         dims: [in_dim, hidden_dim, out_dim]
+         bias: whether to use bias in linear layer.
+     """
+
+     def __init__(self, dims: list[int], activation, bias=True):
+         super().__init__()
+         assert len(dims) == 3
+         self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
+         self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
+         self.activation = activation
+         for m in [self.fc0, self.fc1]:
+             nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc0(x)
+         x = self.activation(x)
+         return self.fc1(x)
+
+
+ class MoonVitEncoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         num_heads: int,
+         hidden_dim: int,
+         mlp_dim: int,
+         *,
+         attn_implementation: str = "flash_attention_2",  # use fa2 in sglang by default
+         activation=F.gelu,
+         attn_bias: bool = False,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
+         self.attn_implementation = attn_implementation
+
+         self.norm0 = nn.LayerNorm(hidden_dim)
+         self.norm1 = nn.LayerNorm(hidden_dim)
+         self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
+         self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
+         self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
+
+     def attention_qkvpacked(
+         self,
+         x: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rope_freqs_cis: Optional[torch.Tensor] = None,
+     ):
+         """
+         Args:
+             x (torch.Tensor): (batch_size, seqlen, hidden_dim)
+             cu_seqlens (torch.Tensor):
+         """
+         xqkv = self.wqkv(x)
+
+         qkv_shape = xqkv.size()[:-1] + (
+             3,
+             self.num_heads,
+             self.hidden_size_per_attention_head,
+         )
+         # xqkv: (batch_size, seqlen, 3, nheads, headdim)
+         xqkv = xqkv.view(*qkv_shape)
+         xq, xk, xv = torch.unbind(xqkv, dim=-3)
+
+         xq, xk = apply_rope(xq, xk, rope_freqs_cis)
+
+         attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
+         attn_out = attn_func(
+             xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
+         )
+
+         attn_out = self.wo(attn_out)
+         return attn_out
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rope_freqs_cis: Union[torch.Tensor, None] = None,
+     ) -> torch.Tensor:
+         """
+         Args:
+             hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set
+
+         Returns:
+             output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input
+         """
+         residual = hidden_states
+         hidden_states = self.norm0(hidden_states)
+         attn_out = self.attention_qkvpacked(
+             hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
+         )
+         hidden_states = residual + attn_out
+
+         residual = hidden_states
+         hidden_states = self.mlp(self.norm1(hidden_states))
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class MoonVitEncoder(nn.Module):
+
+     def __init__(
+         self,
+         hidden_dim: int,
+         num_layers: int,
+         block_cfg: dict,
+     ) -> None:
+         super().__init__()
+
+         self.rope_2d = Rope2DPosEmb(
+             block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
+         )
+         self.blocks = nn.ModuleList(
+             [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]
+         )
+         self.final_layernorm = nn.LayerNorm(hidden_dim)
+
+     def forward(
+         self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
+     ) -> torch.Tensor:
+         rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
+
+         lengths = torch.cat(
+             (
+                 torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
+                 grid_hw[:, 0] * grid_hw[:, 1],
+             )
+         )
+         cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
+
+         for _, block in enumerate(self.blocks):
+             hidden_states = block(
+                 hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
+             )
+
+         hidden_states = self.final_layernorm(hidden_states)
+
+         return hidden_states
+
+
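
An encoder layer consumes the packed hidden states together with cu_seqlens and the per-token freqs_cis, exactly as MoonVitEncoder.forward wires them up. A CPU-only sketch with tiny, made-up dimensions, using the SDPA implementation so neither flash_attn nor a GPU is needed:

import torch

from sglang.srt.models.kimi_vl_moonvit import MoonVitEncoderLayer, Rope2DPosEmb

# hypothetical tiny config: 2 heads x 16 head_dim = 32 hidden_dim
layer = MoonVitEncoderLayer(num_heads=2, hidden_dim=32, mlp_dim=64, attn_implementation="sdpa")
rope = Rope2DPosEmb(dim=16, max_height=64, max_width=64, device="cpu")
grid_hw = torch.tensor([[4, 4], [2, 6]])                   # two images: 16 + 12 patches
x = torch.randn(28, 32)                                    # packed (L, hidden_dim)
cu_seqlens = torch.tensor([0, 16, 28], dtype=torch.int32)  # cumulative boundaries
freqs_cis = rope.get_freqs_cis_by_seqlens(grid_hw)
print(layer(x, cu_seqlens, rope_freqs_cis=freqs_cis).shape)  # torch.Size([28, 32])
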
+ def patch_merger(
+     x: torch.Tensor,
+     grid_hw: torch.Tensor,
+     merge_kernel_size: list[int, int] = (2, 2),
+ ) -> List[torch.Tensor]:
+     d_model = x.size(-1)
+
+     outputs = []
+     pre_sum = 0
+     for x_shape in grid_hw.tolist():
+         height, width = x_shape[0], x_shape[1]
+         # Get the current sequence
+         seq = x[pre_sum : pre_sum + height * width]
+         # Reshape along self.merge_kernel_size and concat to the last dimension
+         kernel_height, kernel_width = merge_kernel_size
+         new_height, new_width = height // kernel_height, width // kernel_width
+         reshaped_seq = seq.view(
+             new_height, kernel_height, new_width, kernel_width, d_model
+         )
+         reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
+         padded_seq = reshaped_seq.view(
+             new_height * new_width, kernel_height * kernel_width, -1
+         )
+         outputs.append(padded_seq)
+         pre_sum += height * width
+
+     return outputs
+
+
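
patch_merger regroups each image's (height * width, d_model) token sequence into height/kh * width/kw merged positions of kh * kw neighboring tokens each; the MoonVitVLProjector defined just below flattens those groups into single vectors. A small sketch with an assumed 4x6 grid and 2x2 kernel:

import torch

from sglang.srt.models.kimi_vl_moonvit import patch_merger

x = torch.randn(4 * 6, 32)  # one image: a 4x6 grid of encoder tokens
groups = patch_merger(x, torch.tensor([[4, 6]]), merge_kernel_size=(2, 2))
print(len(groups), groups[0].shape)  # 1 torch.Size([6, 4, 32])
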
+ class MoonVitVLProjector(nn.Module):
+
+     def __init__(
+         self,
+         in_channels: int,
+         merge_kernel_size: list[int, int],
+         hidden_act: str = "gelu",
+         ln_eps: float = 1e-5,
+         out_dim: int = 4096,
+     ):
+         super().__init__()
+         self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
+
+         self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
+         self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.act = ACT2FN[hidden_act]
+         self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
+         hidden_states = self.linear_1(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.linear_2(hidden_states)
+         return hidden_states
+
+
+ class MoonVitPretrainedModel(PreTrainedModel):
+     config_class = MoonViTConfig
+     model_type = "moonvit"
+     _no_split_modules = ["PackingTransformer"]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+
+     def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
+         super().__init__(config, *inputs, **kwargs)
+         config = deepcopy(config)
+         self.merge_kernel_size = config.merge_kernel_size
+         self.patch_size = config.patch_size
+         self.patch_embed = MoonVisionPatchEmbed(
+             out_dim=config.hidden_size,
+             patch_size=config.patch_size,
+             pos_emb_height=config.init_pos_emb_height,
+             pos_emb_width=config.init_pos_emb_width,
+         )
+
+         self.encoder = MoonVitEncoder(
+             hidden_dim=config.hidden_size,
+             num_layers=config.num_hidden_layers,
+             block_cfg={
+                 "num_heads": config.num_attention_heads,
+                 "hidden_dim": config.hidden_size,
+                 "mlp_dim": config.intermediate_size,
+                 "activation": PytorchGELUTanh(),
+                 "attn_bias": True,
+                 "attn_implementation": config._attn_implementation,
+             },
+         )
+
+     def forward(
+         self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Args:
+             pixel_values (torch.Tensor): The input pixel values.
+             grid_hw (torch.Tensor): The grid height and width.
+
+         Returns:
+             torch.Tensor: The output tokens.
+         """
+         hidden_states = self.patch_embed(pixel_values, grid_hw)
+         hidden_states = self.encoder(hidden_states, grid_hw)
+         hidden_states = patch_merger(
+             hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
+         )
+         return hidden_states