sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/models/glm4.py
@@ -0,0 +1,312 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Modeling from:
+# ./llama.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
+"""Inference-only GLM4 model compatible with THUDM weights."""
+
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import Glm4Config
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaMLP as Glm4MLP
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class Glm4Attention(nn.Module):
+    def __init__(
+        self,
+        config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = getattr(config, "rope_theta", 1000000)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=self.rope_scaling,
+            partial_rotary_factor=partial_rotary_factor,
+            is_neox_style=False,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+        )
+        attn_output, _ = self.o_proj(context_layer)
+        return attn_output
+
+
+class Glm4DecoderLayer(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        # Self attention.
+        self.self_attn = Glm4Attention(
+            config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+        )
+
+        # MLP
+        self.mlp = Glm4MLP(
+            config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_self_attn_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = self.post_self_attn_layernorm(hidden_states)
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_mlp_layernorm(hidden_states)
+
+        return hidden_states, residual
+
+
+class Glm4Model(nn.Module):
+    def __init__(
+        self,
+        config: Glm4Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Glm4DecoderLayer(
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+            ),
+            prefix="model.layers",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
+
+
+class Glm4ForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: Glm4Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config: Glm4Config = config
+        self.quant_config = quant_config
+        self.model = Glm4Model(config, quant_config, add_prefix("model", prefix))
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix="lm_head",
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, weight_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    raise KeyError(f"Parameter '{name}' not found in model.")
+
+
+EntryClass = [Glm4ForCausalLM]
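The hunk above is the entire new sglang/srt/models/glm4.py; the final `EntryClass = [Glm4ForCausalLM]` line is what registers the architecture with the model loader. As a quick way to exercise it, here is a minimal sketch using sglang's offline Engine API; the checkpoint name `THUDM/GLM-4-9B-0414` is an assumption, so substitute whatever GLM-4 weights you actually use.

```python
# Hedged sketch: drive the newly registered Glm4ForCausalLM through sglang's
# offline Engine API. The model path below is an assumed checkpoint name.
import sglang as sgl


def main() -> None:
    llm = sgl.Engine(model_path="THUDM/GLM-4-9B-0414")  # assumed checkpoint
    prompts = [
        "Briefly explain what tensor parallelism does.",
        "What does RMSNorm normalize?",
    ]
    sampling_params = {"temperature": 0.7, "max_new_tokens": 64}

    # Engine.generate returns one result dict (with a "text" field) per prompt.
    for prompt, out in zip(prompts, llm.generate(prompts, sampling_params)):
        print(f"{prompt!r} -> {out['text']!r}")

    llm.shutdown()


if __name__ == "__main__":
    main()
```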
--- a/sglang/srt/models/internvl.py
+++ b/sglang/srt/models/internvl.py
@@ -11,21 +11,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
 # Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
 import torch.nn.functional as F
-from einops import rearrange, repeat
-from sgl_kernel.flash_attn import flash_attn_varlen_func
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
+from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.utils import logger
 
 
-class FlashAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-            (default: 1/sqrt(d_keys) where d_keys is computed at
-            runtime)
-        attention_dropout: The dropout rate to apply to the attention
-            (default: 0.0)
-    """
-
+class InternAttention(nn.Module):
     def __init__(
-        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
-    ):
-        super().__init__()
-        self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
-
-    def forward(
         self,
-        qkv,
-        causal=False,
-        max_s=None,
+        config,
+        quant_config: QuantizationConfig = None,
     ):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
-                if unpadded: (nnz, 3, h, d)
-        """
-        assert qkv.dtype in [torch.float16, torch.bfloat16]
-        assert qkv.is_cuda
-
-        batch_size, seqlen, _, nheads, d = qkv.shape
-        if batch_size == 0 or seqlen == 0:
-            output_shape = (batch_size, seqlen, nheads, d)
-            return (
-                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
-                None,
-            )
-
-        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
-        q, k, v = qkv_reshaped.unbind(1)
-
-        max_s = seqlen
-        cu_seqlens = torch.arange(
-            0,
-            (batch_size + 1) * seqlen,
-            step=seqlen,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output_reshaped = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            softmax_scale=self.softmax_scale,
-            causal=causal,
-        )
-        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
-        return output, None
-
-
-class InternAttention(nn.Module):
-    def __init__(self, config):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
 
         self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+
+        self.attn = VisionAttention(
+            qkv_backend="fa3",
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,
+            projection_size=self.embed_dim,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=getattr(config, "dropout", 0.0),
+            proj_bias=getattr(config, "qkv_bias", True),
+            flatten_batch=False,
+        )
+
         self.proj_drop = nn.Dropout(config.dropout)
 
         self.qk_normalization = config.qk_normalization
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
             self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
 
-        self.inner_attn = FlashAttention(softmax_scale=self.scale)
-
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-    def _flash_attn(
+    def forward(
         self,
-        x,
-    ):
-        qkv = self.qkv(x)
-        qkv = rearrange(
-            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
-        )
-
-        if self.qk_normalization:
-            q, k, v = qkv.unbind(2)
-            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
-            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
-            qkv = torch.stack([q, k, v], dim=2)
-
-        context, _ = self.inner_attn(
-            qkv,
-        )
-        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
-        outs = self.proj_drop(outs)
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> torch.Tensor:
+        out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
+        outs = self.proj_drop(out)
         return outs
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._flash_attn(hidden_states)
-        return x
-
 
 class InternVisionEmbeddings(nn.Module):
     def __init__(self, config: PretrainedConfig):
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
     def forward(
         self,
        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
     ) -> Tuple[
        torch.FloatTensor,
        Optional[torch.FloatTensor],
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
+
         hidden_states = hidden_states + self.drop_path1(
-            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+            self.attn(
+                self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
+            )
+            * self.ls1
         )
 
         hidden_states = hidden_states + self.drop_path2(
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
         encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds
 
+        cu_seqlens = SingletonCache()
+
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-            )
+            layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
             hidden_states = layer_outputs
 
         if output_hidden_states:
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
                     weight_loader(param, loaded_weight, shard_id)
                     break
                 else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.", r"attn.attn.")
+                        name = name.replace(r"qkv.", r"qkv_proj.")
+
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
                         param, "weight_loader", default_weight_loader
                     )
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                f"Some weights are not initialized from checkpoints: {unloaded_params}"
+            )
+        return loaded_params
 
 
 EntryClass = InternVLChatModel
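The InternVL changes above route the vision tower's attention through `VisionAttention`, which takes a `cu_seqlens` tensor (cumulative sequence lengths, the convention used by variable-length flash attention) instead of a padded (B, S, 3, H, D) QKV tensor; the `SingletonCache` created before the layer loop presumably lets that value be computed once and reused by every encoder layer. For reference, a tiny illustration of what such a `cu_seqlens` tensor looks like (the lengths are made up):

```python
import torch

# Per-image token counts in a batch of three images (made-up numbers).
seq_lens = torch.tensor([576, 576, 1024])

# cu_seqlens is the cumulative lengths with a leading zero: [0, 576, 1152, 2176].
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([   0,  576, 1152, 2176], dtype=torch.int32)
```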
--- a/sglang/srt/models/mimo_mtp.py
+++ b/sglang/srt/models/mimo_mtp.py
@@ -7,33 +7,17 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.mimo import MiMoForCausalLM
-from sglang.srt.models.qwen2 import (
-    Qwen2Attention,
-    Qwen2DecoderLayer,
-    Qwen2MLP,
-    Qwen2Model,
-)
-from sglang.srt.utils import add_prefix
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer
 
 
 class MiMoMultiTokenPredictorLayer(nn.Module):