sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0

sglang/srt/models/internvl.py
@@ -0,0 +1,670 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+
+# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
+# Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from sgl_kernel.flash_attn import flash_attn_varlen_func
+from torch import nn
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_janus_pro import DropPath
+from sglang.srt.models.internlm2 import InternLM2ForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.utils import logger
+
+
+class FlashAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(
+        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
+    ):
+        super().__init__()
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(
+        self,
+        qkv,
+        causal=False,
+        max_s=None,
+    ):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+                if unpadded: (nnz, 3, h, d)
+        """
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
+        assert qkv.is_cuda
+
+        batch_size, seqlen, _, nheads, d = qkv.shape
+        if batch_size == 0 or seqlen == 0:
+            output_shape = (batch_size, seqlen, nheads, d)
+            return (
+                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
+                None,
+            )
+
+        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
+        q, k, v = qkv_reshaped.unbind(1)
+
+        max_s = seqlen
+        cu_seqlens = torch.arange(
+            0,
+            (batch_size + 1) * seqlen,
+            step=seqlen,
+            dtype=torch.int32,
+            device=qkv.device,
+        )
+        output_reshaped = flash_attn_varlen_func(
+            q,
+            k,
+            v,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            softmax_scale=self.softmax_scale,
+            causal=causal,
+        )
+        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
+        return output, None
+
+
+class InternAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+
+        self.scale = self.head_dim**-0.5
+        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+        self.proj_drop = nn.Dropout(config.dropout)
+
+        self.qk_normalization = config.qk_normalization
+
+        if self.qk_normalization:
+            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+        self.inner_attn = FlashAttention(softmax_scale=self.scale)
+
+        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def _flash_attn(
+        self,
+        x,
+    ):
+        qkv = self.qkv(x)
+        qkv = rearrange(
+            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
+        )
+
+        if self.qk_normalization:
+            q, k, v = qkv.unbind(2)
+            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+            qkv = torch.stack([q, k, v], dim=2)
+
+        context, _ = self.inner_attn(
+            qkv,
+        )
+        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+        outs = self.proj_drop(outs)
+        return outs
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        x = self._flash_attn(hidden_states)
+        return x
+
+
+class InternVisionEmbeddings(nn.Module):
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = nn.Parameter(
+            torch.randn(1, 1, self.embed_dim),
+        )
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=3,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+
+        self.position_embedding = nn.Parameter(
+            torch.randn(1, self.num_positions, self.embed_dim)
+        )
+
+    def _get_pos_embed(self, pos_embed, H, W):
+        target_dtype = pos_embed.dtype
+        pos_embed = (
+            pos_embed.float()
+            .reshape(
+                1,
+                self.image_size // self.patch_size,
+                self.image_size // self.patch_size,
+                -1,
+            )
+            .permute(0, 3, 1, 2)
+        )
+        pos_embed = (
+            F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
+            .reshape(1, -1, H * W)
+            .permute(0, 2, 1)
+            .to(target_dtype)
+        )
+        return pos_embed
+
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(
+            pixel_values
+        )  # shape = [*, channel, width, height]
+        batch_size, _, height, width = patch_embeds.shape
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        position_embedding = torch.cat(
+            [
+                self.position_embedding[:, :1, :],
+                self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
+            ],
+            dim=1,
+        )
+        embeddings = embeddings + position_embedding.to(target_dtype)
+        return embeddings
+
+
+class InternRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+class InternMLP(nn.Module):
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.act = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+NORM2FN = {
+    "rms_norm": InternRMSNorm,
+    "layer_norm": nn.LayerNorm,
+}
+
+
+class InternVisionEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        drop_path_rate: float,
+        quant_config: QuantizationConfig = None,
+    ):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.norm_type = config.norm_type
+        self.attn = InternAttention(config)
+        self.mlp = InternMLP(config)
+        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+
+        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+        self.drop_path1 = (
+            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        )
+        self.drop_path2 = (
+            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Tuple[
+        torch.FloatTensor,
+        Optional[torch.FloatTensor],
+        Optional[Tuple[torch.FloatTensor]],
+    ]:
+        """
+        Args:
+            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+        hidden_states = hidden_states + self.drop_path1(
+            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+        )
+
+        hidden_states = hidden_states + self.drop_path2(
+            self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
+        )
+
+        return hidden_states
+
+
+class InternVisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`InternEncoderLayer`].
+
+    Args:
+        config (`InternConfig`):
+            The corresponding vision configuration for the `InternEncoder`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        # stochastic depth decay rule
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
+        ]
+        self.layers = nn.ModuleList(
+            [
+                InternVisionEncoderLayer(config, dpr[idx], quant_config)
+                for idx in range(config.num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs_embeds,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Embedded representation of the inputs. Should be float, not int tokens.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        encoder_states = () if output_hidden_states else None
+        hidden_states = inputs_embeds
+
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            layer_outputs = encoder_layer(
+                hidden_states,
+            )
+            hidden_states = layer_outputs
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states
+        )
+
+
+class InternVisionModel(PreTrainedModel):
+    main_input_name = "pixel_values"
+    _supports_flash_attn_2 = True
+    config_class = PretrainedConfig
+    _no_split_modules = ["InternVisionEncoderLayer"]
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = InternVisionEmbeddings(
+            config,
+        )
+        self.encoder = InternVisionEncoder(config, quant_config)
+
+    def resize_pos_embeddings(self, old_size, new_size, patch_size):
+        pos_emb = self.embeddings.position_embedding
+        _, num_positions, embed_dim = pos_emb.shape
+        cls_emb = pos_emb[:, :1, :]
+        pos_emb = (
+            pos_emb[:, 1:, :]
+            .reshape(1, old_size // patch_size, old_size // patch_size, -1)
+            .permute(0, 3, 1, 2)
+        )
+        pos_emb = F.interpolate(
+            pos_emb.float(),
+            size=new_size // patch_size,
+            mode="bicubic",
+            align_corners=False,
+        )
+        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
+        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
+        self.embeddings.position_embedding = nn.Parameter(pos_emb)
+        self.embeddings.image_size = new_size
+        logger.info(
+            "Resized position embeddings from {} to {}".format(old_size, new_size)
+        )
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_embeds: Optional[torch.FloatTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if pixel_values is None and pixel_embeds is None:
+            raise ValueError("You have to specify pixel_values or pixel_embeds")
+
+        if pixel_embeds is not None:
+            hidden_states = pixel_embeds
+        else:
+            if len(pixel_values.shape) == 4:
+                hidden_states = self.embeddings(pixel_values)
+            else:
+                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        last_hidden_state = encoder_outputs.last_hidden_state
+        pooled_output = last_hidden_state[:, 0, :]
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class InternVLChatModel(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        use_flash_attn=True,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+
+        image_size = config.force_image_size or config.vision_config.image_size
+        patch_size = config.vision_config.patch_size
+        self.patch_size = patch_size
+        self.select_layer = config.select_layer
+        self.template = config.template
+        self.num_image_token = int(
+            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
+        )
+        self.downsample_ratio = config.downsample_ratio
+        self.ps_version = config.ps_version
+
+        config.vision_config.use_flash_attn = True if use_flash_attn else False
+        config.llm_config._attn_implementation = (
+            "flash_attention_2" if use_flash_attn else "eager"
+        )
+
+        logger.info(f"num_image_token: {self.num_image_token}")
+        logger.info(f"ps_version: {self.ps_version}")
+
+        self.vision_model = InternVisionModel(config.vision_config)
+        if config.llm_config.architectures[0] == "Qwen2ForCausalLM":
+            self.language_model = Qwen2ForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
+        elif config.llm_config.architectures[0] == "InternLM2ForCausalLM":
+            self.language_model = InternLM2ForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
+        else:
+            raise NotImplementedError(
+                f"{config.llm_config.architectures[0]} is not implemented."
+            )
+
+        vit_hidden_size = config.vision_config.hidden_size
+        llm_hidden_size = config.llm_config.hidden_size
+
+        self.mlp1 = nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+            nn.Linear(
+                vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
+            ),
+            nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size),
+        )
+
+    def pixel_shuffle(self, x, scale_factor=0.5):
+        n, w, h, c = x.size()
+        # N, W, H, C --> N, W, H * scale, C // scale
+        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+        x = x.permute(0, 2, 1, 3).contiguous()
+        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+        x = x.view(
+            n,
+            int(h * scale_factor),
+            int(w * scale_factor),
+            int(c / (scale_factor * scale_factor)),
+        )
+        if self.ps_version == "v1":
+            logger.warn(
+                "In ps_version 'v1', the height and width have not been swapped back, "
+                "which results in a transposed image."
+            )
+        else:
+            x = x.permute(0, 2, 1, 3).contiguous()
+        return x
+
+    def extract_feature(self, pixel_values):
+        if self.select_layer == -1:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values, output_hidden_states=False, return_dict=True
+            ).last_hidden_state
+        else:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
+            ).hidden_states[self.select_layer]
+        vit_embeds = vit_embeds[:, 1:, :]
+
+        h = w = int(vit_embeds.shape[1] ** 0.5)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+        vit_embeds = self.mlp1(vit_embeds)
+        return vit_embeds
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+        """
+        Projects the last hidden state from the vision model into language model space.
+
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+        """
+        pixel_values = torch.cat([item.pixel_values for item in items])
+        image_features = self.extract_feature(pixel_values)
+        return image_features
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id)]
+        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return helper.pad_input_tokens(input_ids, mm_inputs)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("gate_up_proj", "w1", 0),
+                ("gate_up_proj", "w3", 1),
+            ]
+        elif "Qwen2ForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                if "wqkv" in name:
+                    config = self.config
+                    kv_groups = config.num_attention_heads // config.num_key_value_heads
+                    head_dim = config.hidden_size // config.num_attention_heads
+                    loaded_weight = loaded_weight.view(
+                        -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+                    )
+                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
+                    wq = wq.reshape(-1, wq.shape[-1])
+                    wk = wk.reshape(-1, wk.shape[-1])
+                    wv = wv.reshape(-1, wv.shape[-1])
+                    weight_loader = param.weight_loader
+                    weight_loader(param, wq, "q")
+                    weight_loader(param, wk, "k")
+                    weight_loader(param, wv, "v")
+                else:
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = InternVLChatModel
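
The arithmetic behind num_image_token and pixel_shuffle above determines how many tokens each image tile contributes to the language model. The following is an illustrative sketch, not part of the wheel; it assumes the common InternVL configuration of image_size=448, patch_size=14, and downsample_ratio=0.5 (other checkpoints may use different values).

# Hedged sketch of the image-token arithmetic in InternVLChatModel.__init__,
# assuming image_size=448, patch_size=14, downsample_ratio=0.5.
image_size, patch_size, downsample_ratio = 448, 14, 0.5

patches_per_side = image_size // patch_size              # 32 patches per side
vit_tokens = patches_per_side**2                         # 1024 patch tokens (CLS dropped in extract_feature)
num_image_token = int(vit_tokens * downsample_ratio**2)  # 1024 * 0.25 = 256 tokens per tile

# pixel_shuffle folds each 2x2 spatial block into the channel dimension, so the
# 32x32 grid becomes a 16x16 grid with 4x the channel width before mlp1 projects
# it into the language model's hidden size.
assert num_image_token == (patches_per_side // 2) ** 2 == 256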

sglang/srt/models/llama.py
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -90,7 +91,7 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_batch=None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
                 config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
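
The llama.py hunks above read the new enable_dp_lm_head entry from global_server_args_dict, whose keys mirror the server arguments. Below is a minimal sketch of turning the feature on, assuming ServerArgs in this release exposes a boolean field of the same name; check sglang/srt/server_args.py for the exact field and its interaction with DP attention.

# Hedged sketch: enabling the DP lm_head path referenced above. The
# `enable_dp_lm_head` field is assumed to mirror the global_server_args_dict key.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # any Llama-family checkpoint
    tp_size=8,
    dp_size=8,
    enable_dp_attention=True,  # existing flag: data-parallel attention groups
    enable_dp_lm_head=True,    # assumed new flag wired to ParallelLMHead(use_attn_tp_group=...)
)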