sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py CHANGED
@@ -97,7 +97,7 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
@@ -109,7 +109,7 @@ class Gemma2Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.layer_idx = layer_idx
+        self.layer_id = layer_id
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -156,13 +156,13 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )

-        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
+        use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            layer_id=layer_idx,
+            layer_id=layer_id,
             logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
@@ -188,7 +188,7 @@ class Gemma2Attention(nn.Module):
 class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -196,7 +196,7 @@ class Gemma2DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-            layer_idx=layer_idx,
+            layer_id=layer_id,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -269,8 +269,8 @@ class Gemma2Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-                for layer_idx in range(config.num_hidden_layers)
+                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -332,6 +332,7 @@ class Gemma2ForCausalLM(nn.Module):
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}
     embedding_padding_modules = []
+    supports_lora = True

     def __init__(
         self,
sglang/srt/models/llava.py CHANGED
@@ -345,7 +345,7 @@ class LlavaBaseForCausalLM(nn.Module):

         # Fill in the placeholder for the image
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-        prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
+        prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
         pt = 0
         for i in range(bs):
             if not need_vision[i]:
sglang/srt/models/llavavid.py CHANGED
@@ -169,7 +169,7 @@ class LlavaVidForCausalLM(nn.Module):

         # Fill in the placeholder for the image
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-        prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
+        prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
         pt = 0
         for i in range(bs):
             if not need_vision[i]:
sglang/srt/models/olmo.py CHANGED
@@ -223,8 +223,8 @@ class OlmoModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                OlmoDecoderLayer(config, layer_idx, quant_config)
-                for layer_idx in range(config.num_hidden_layers)
+                OlmoDecoderLayer(config, layer_id, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = nn.LayerNorm(
@@ -250,7 +250,7 @@ class OlmoModel(nn.Module):
         hidden_states = input_embeds

         # Apply blocks one-by-one.
-        for layer_idx, decoder_layer in enumerate(self.layers):
+        for layer_id, decoder_layer in enumerate(self.layers):
             # shape: (batch_size, seq_len, d_model)
             hidden_states = decoder_layer(
                 positions,
sglang/srt/models/phi3_small.py ADDED
@@ -0,0 +1,447 @@
+import math
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import Phi3Config
+from transformers.configuration_utils import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import make_layers
+
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+@torch.jit.script
+def quick_gelu(x):
+    return x * torch.sigmoid(1.702 * x)
+
+
+@torch.jit.script
+def gegelu(input, limit: Optional[float] = None):
+    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
+    if limit is not None:
+        a_gelu = torch.where(
+            torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
+        )
+        a_linear = torch.where(
+            torch.isinf(a_linear),
+            a_linear,
+            a_linear.clamp(min=-limit, max=limit),
+        )
+    out_gelu = quick_gelu(a_gelu)
+    return out_gelu * (a_linear + 1)
+
+
+class Phi3SmallMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        assert (
+            self.config.hidden_act == "gegelu"
+        ), "Only `gegelu` is supported for the 4.7 series of models .."
+        self.hidden_size = config.hidden_size
+        self.gegelu_limit = config.gegelu_limit
+        self.intermediate_size = config.intermediate_size
+
+        self.up_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            2 * [self.intermediate_size],
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+    def forward(self, x):
+        gate_up, _ = self.up_proj(x)
+        x = gegelu(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Phi3SmallSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.config = config
+        self.sparse_block_size = config.blocksparse_block_size
+        self.homo_heads = config.blocksparse_homo_head_pattern
+        self.local_blocks = config.blocksparse_num_local_blocks
+        self.vert_stride = config.blocksparse_vert_stride
+
+        assert (
+            config.blocksparse_block_size == config.blocksparse_triton_kernel_block_size
+        )
+
+        self.hidden_size = config.hidden_size
+        # Number of Query Heads
+        self.num_heads = config.num_attention_heads
+
+        self.head_dim = self.hidden_size // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # Number of total Key Value Heads before tensor parallel
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_q_per_kv = self.num_heads // self.num_key_value_heads
+        if self.tp_size > 1:
+            assert self.num_key_value_heads % self.tp_size == 0
+        self.num_kv_heads_per_partion = max(1, self.num_key_value_heads // self.tp_size)
+        self.num_heads_per_partition = self.num_heads // self.tp_size
+
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_embedding_base = config.rope_embedding_base
+        self.rope_position_scale = config.rope_position_scale
+        self.is_causal = True
+
+        norm_factor = None
+        if config.mup_use_scaling:
+            norm_factor = self.head_dim / config.mup_attn_multiplier
+        else:
+            norm_factor = math.sqrt(self.head_dim)
+        self.scale = 1 / norm_factor
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.num_heads,
+            self.num_key_value_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        if getattr(self.config, "rope_scaling", None) is not None:
+            rope_scaling = self.config.rope_scaling
+            for key in rope_scaling:
+                if isinstance(rope_scaling[key], list):
+                    rope_scaling[key] = tuple(rope_scaling[key])
+
+            if "factor" not in rope_scaling:
+                rope_scaling["factor"] = self.rope_position_scale
+        else:
+            rope_scaling = {
+                "rope_type": "linear",
+                "factor": self.rope_position_scale,
+            }
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_embedding_base,
+            rope_scaling=rope_scaling,
+        )
+
+        # blocksparse params
+        self.blocksparse_block_size = config.blocksparse_block_size
+        self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
+        self.blocksparse_vert_stride = config.blocksparse_vert_stride
+
+        use_dense_attn = (
+            getattr(self.config, "dense_attention_every_n_layers", None)
+            and (self.layer_id + 1) % self.config.dense_attention_every_n_layers == 0
+        )
+
+        bs_params = None
+        if not use_dense_attn:
+            bs_params = {
+                "max_seqlen": self.max_position_embeddings,
+                "num_heads": self.num_heads_per_partition,
+                "num_kv_heads": self.num_kv_heads_per_partion,
+                "block_size": self.sparse_block_size,
+                "local_blocks": self.local_blocks,
+                "vert_stride": self.vert_stride,
+                "homo_head": self.homo_heads,
+            }
+
+        self.attn = RadixAttention(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads_per_partion,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        qkv, _ = self.query_key_value(hidden_states)
+
+        qkv = qkv.view(qkv.shape[:-1] + (-1, (self.num_q_per_kv + 2), self.head_dim))
+        q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
+
+        # NOTE: this is required by RotaryEmbed, which indeed does not have to
+        # TODO: allow 3D QK for rotary forward
+        q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
+        k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+        v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+        output, _ = self.dense(attn_output)
+
+        return output
+
+
+class Phi3SmallDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        cache_config=None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Phi3SmallSelfAttention(
+            config, layer_id, quant_config=quant_config
+        )
+        self.mlp = Phi3SmallMLP(config, quant_config)
+
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon
+        )
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class Phi3SmallModel(nn.Module):
+
+    def __init__(
+        self,
+        config: Phi3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.config = config
+        cache_config = None
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.mup_embedding_multiplier = config.mup_embedding_multiplier
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Phi3SmallDecoderLayer(
+                config, int(prefix.split(".")[-1]), cache_config, quant_config
+            ),
+            prefix=f"{prefix}.layers",
+        )
+
+        self.final_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor],
+    ) -> Union[torch.Tensor]:
+
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
+            if (
+                self.mup_embedding_multiplier is not None
+                and self.mup_embedding_multiplier > 0.0
+            ):
+                hidden_states = hidden_states * self.mup_embedding_multiplier
+
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states = layer(positions, hidden_states, forward_batch=forward_batch)
+
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class Phi3SmallForCausalLM(nn.Module):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(
+        self,
+        config: Phi3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ):
+
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Phi3SmallModel(
+            config=config,
+            quant_config=quant_config,
+            prefix="model",
+        )
+        self.torchao_config = global_server_args_dict["torchao_config"]
+        self.vocab_size = config.vocab_size
+        self.mup_width_multiplier = config.mup_width_multiplier
+        self.lm_head = ParallelLMHead(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        # tokens in tiktoken but not used
+        if hasattr(config, "dummy_token_indices"):
+            device = self.lm_head.weight.device
+            self.register_buffer(
+                "dummy_token_indices",
+                torch.LongTensor(config.dummy_token_indices).to(device),
+                persistent=False,
+            )
+        else:
+            self.dummy_token_indices = None
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, value):
+        self.lm_head = value
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+        if self.dummy_token_indices is not None and logits is not None:
+            logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
+        return logits
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
+
+EntryClass = Phi3SmallForCausalLM
sglang/srt/models/qwen2_vl.py CHANGED
@@ -44,6 +44,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 )
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -559,6 +560,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )

         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -577,6 +579,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        get_embedding: bool = False,
     ):
         """Run forward pass for Qwen2-VL.

@@ -599,8 +602,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             image_inputs = [
                 img for img in forward_batch.image_inputs if img is not None
             ]
-
-        positions = forward_batch.mrope_positions
+        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            positions = forward_batch.mrope_positions
         if (
             forward_batch.forward_mode.is_decode()
             or image_inputs is None
@@ -616,7 +619,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):

             inputs_embeds = self.model.embed_tokens(input_ids)
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
+            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
             for i, image in enumerate(forward_batch.image_inputs):
                 if image is None:
                     continue
@@ -655,9 +658,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
         )
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [