sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/phi3_small.py ADDED
@@ -0,0 +1,447 @@
+ import math
+ from typing import Iterable, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import Phi3Config
+ from transformers.configuration_utils import PretrainedConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.models.utils import make_layers
+
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.pooler import Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     DEFAULT_VOCAB_PADDING_SIZE,
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+ @torch.jit.script
+ def quick_gelu(x):
+     return x * torch.sigmoid(1.702 * x)
+
+
+ @torch.jit.script
+ def gegelu(input, limit: Optional[float] = None):
+     a_gelu, a_linear = input[..., ::2], input[..., 1::2]
+     if limit is not None:
+         a_gelu = torch.where(
+             torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
+         )
+         a_linear = torch.where(
+             torch.isinf(a_linear),
+             a_linear,
+             a_linear.clamp(min=-limit, max=limit),
+         )
+     out_gelu = quick_gelu(a_gelu)
+     return out_gelu * (a_linear + 1)
+
+
+ class Phi3SmallMLP(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         assert (
+             self.config.hidden_act == "gegelu"
+         ), "Only `gegelu` is supported for the 4.7 series of models .."
+         self.hidden_size = config.hidden_size
+         self.gegelu_limit = config.gegelu_limit
+         self.intermediate_size = config.intermediate_size
+
+         self.up_proj = MergedColumnParallelLinear(
+             self.hidden_size,
+             2 * [self.intermediate_size],
+             bias=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.up_proj",
+         )
+         self.down_proj = RowParallelLinear(
+             self.intermediate_size,
+             self.hidden_size,
+             bias=True,
+             quant_config=quant_config,
+         )
+
+     def forward(self, x):
+         gate_up, _ = self.up_proj(x)
+         x = gegelu(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class Phi3SmallSelfAttention(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.layer_id = layer_id
+         self.config = config
+         self.sparse_block_size = config.blocksparse_block_size
+         self.homo_heads = config.blocksparse_homo_head_pattern
+         self.local_blocks = config.blocksparse_num_local_blocks
+         self.vert_stride = config.blocksparse_vert_stride
+
+         assert (
+             config.blocksparse_block_size == config.blocksparse_triton_kernel_block_size
+         )
+
+         self.hidden_size = config.hidden_size
+         # Number of Query Heads
+         self.num_heads = config.num_attention_heads
+
+         self.head_dim = self.hidden_size // self.num_heads
+         self.tp_size = get_tensor_model_parallel_world_size()
+         # Number of total Key Value Heads before tensor parallel
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_q_per_kv = self.num_heads // self.num_key_value_heads
+         if self.tp_size > 1:
+             assert self.num_key_value_heads % self.tp_size == 0
+         self.num_kv_heads_per_partion = max(1, self.num_key_value_heads // self.tp_size)
+         self.num_heads_per_partition = self.num_heads // self.tp_size
+
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_embedding_base = config.rope_embedding_base
+         self.rope_position_scale = config.rope_position_scale
+         self.is_causal = True
+
+         norm_factor = None
+         if config.mup_use_scaling:
+             norm_factor = self.head_dim / config.mup_attn_multiplier
+         else:
+             norm_factor = math.sqrt(self.head_dim)
+         self.scale = 1 / norm_factor
+
+         self.query_key_value = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.num_heads,
+             self.num_key_value_heads,
+             bias=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
+         )
+
+         self.dense = RowParallelLinear(
+             self.hidden_size,
+             self.hidden_size,
+             bias=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.o_proj",
+         )
+
+         if getattr(self.config, "rope_scaling", None) is not None:
+             rope_scaling = self.config.rope_scaling
+             for key in rope_scaling:
+                 if isinstance(rope_scaling[key], list):
+                     rope_scaling[key] = tuple(rope_scaling[key])
+
+             if "factor" not in rope_scaling:
+                 rope_scaling["factor"] = self.rope_position_scale
+         else:
+             rope_scaling = {
+                 "rope_type": "linear",
+                 "factor": self.rope_position_scale,
+             }
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=self.max_position_embeddings,
+             base=self.rope_embedding_base,
+             rope_scaling=rope_scaling,
+         )
+
+         # blocksparse params
+         self.blocksparse_block_size = config.blocksparse_block_size
+         self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
+         self.blocksparse_vert_stride = config.blocksparse_vert_stride
+
+         use_dense_attn = (
+             getattr(self.config, "dense_attention_every_n_layers", None)
+             and (self.layer_id + 1) % self.config.dense_attention_every_n_layers == 0
+         )
+
+         bs_params = None
+         if not use_dense_attn:
+             bs_params = {
+                 "max_seqlen": self.max_position_embeddings,
+                 "num_heads": self.num_heads_per_partition,
+                 "num_kv_heads": self.num_kv_heads_per_partion,
+                 "block_size": self.sparse_block_size,
+                 "local_blocks": self.local_blocks,
+                 "vert_stride": self.vert_stride,
+                 "homo_head": self.homo_heads,
+             }
+
+         self.attn = RadixAttention(
+             self.num_heads_per_partition,
+             self.head_dim,
+             self.scale,
+             num_kv_heads=self.num_kv_heads_per_partion,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         qkv, _ = self.query_key_value(hidden_states)
+
+         qkv = qkv.view(qkv.shape[:-1] + (-1, (self.num_q_per_kv + 2), self.head_dim))
+         q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
+
+         # NOTE: this is required by RotaryEmbed, which indeed does not have to
+         # TODO: allow 3D QK for rotary forward
+         q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
+         k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+         v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+         output, _ = self.dense(attn_output)
+
+         return output
+
+
+ class Phi3SmallDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         cache_config=None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = Phi3SmallSelfAttention(
+             config, layer_id, quant_config=quant_config
+         )
+         self.mlp = Phi3SmallMLP(config, quant_config)
+
+         self.input_layernorm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_epsilon
+         )
+         self.post_attention_layernorm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_epsilon
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class Phi3SmallModel(nn.Module):
+
+     def __init__(
+         self,
+         config: Phi3Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+
+         self.config = config
+         cache_config = None
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size, config.hidden_size
+         )
+         self.mup_embedding_multiplier = config.mup_embedding_multiplier
+         self.start_layer, self.end_layer, self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda prefix: Phi3SmallDecoderLayer(
+                 config, int(prefix.split(".")[-1]), cache_config, quant_config
+             ),
+             prefix=f"{prefix}.layers",
+         )
+
+         self.final_layernorm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_epsilon
+         )
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.embed_tokens(input_ids)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         positions: Optional[torch.LongTensor],
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor],
+     ) -> Union[torch.Tensor]:
+
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds
+         else:
+             hidden_states = self.get_input_embeddings(input_ids)
+             if (
+                 self.mup_embedding_multiplier is not None
+                 and self.mup_embedding_multiplier > 0.0
+             ):
+                 hidden_states = hidden_states * self.mup_embedding_multiplier
+
+         for i in range(self.start_layer, self.end_layer):
+             layer = self.layers[i]
+             hidden_states = layer(positions, hidden_states, forward_batch=forward_batch)
+
+         hidden_states = self.final_layernorm(hidden_states)
+         return hidden_states
+
+
+ class Phi3SmallForCausalLM(nn.Module):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(
+         self,
+         config: Phi3Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config=None,
+     ):
+
+         super().__init__()
+
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Phi3SmallModel(
+             config=config,
+             quant_config=quant_config,
+             prefix="model",
+         )
+         self.torchao_config = global_server_args_dict["torchao_config"]
+         self.vocab_size = config.vocab_size
+         self.mup_width_multiplier = config.mup_width_multiplier
+         self.lm_head = ParallelLMHead(
+             self.vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+             quant_config=quant_config,
+         )
+         if self.config.tie_word_embeddings:
+             self.lm_head.weight = self.model.embed_tokens.weight
+         self.logits_processor = LogitsProcessor(config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+         # tokens in tiktoken but not used
+         if hasattr(config, "dummy_token_indices"):
+             device = self.lm_head.weight.device
+             self.register_buffer(
+                 "dummy_token_indices",
+                 torch.LongTensor(config.dummy_token_indices).to(device),
+                 persistent=False,
+             )
+         else:
+             self.dummy_token_indices = None
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, value):
+         self.lm_head = value
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def compute_logits(
+         self,
+         hidden_states: torch.Tensor,
+         sampling_metadata,
+     ) -> Optional[torch.Tensor]:
+         logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+         if self.dummy_token_indices is not None and logits is not None:
+             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
+         return logits
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         positions: Optional[torch.LongTensor],
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         get_embedding: bool = False,
+     ) -> LogitsProcessorOutput:
+         hidden_states = self.model(
+             input_ids=input_ids,
+             positions=positions,
+             forward_batch=forward_batch,
+             inputs_embeds=inputs_embeds,
+         )
+
+         if not get_embedding:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head.weight, forward_batch
+             )
+
+         else:
+             return self.pooler(hidden_states, forward_batch)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             if name.endswith(".bias") and name not in params_dict:
+                 continue
+
+             param = params_dict[name]
+             weight_loader = getattr(param, "weight_loader", default_weight_loader)
+             weight_loader(param, loaded_weight)
+
+         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
+
+ EntryClass = Phi3SmallForCausalLM
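
Note on the new model file above: gegelu consumes a projection whose channels interleave a GELU gate (even indices) with a linear branch (odd indices), so its output has half as many channels as its input. Below is a standalone restatement in plain PyTorch for readers who want to check the shapes; it drops the @torch.jit.script decorators and the redundant min=None kwarg but otherwise mirrors the function in the diff:

```python
from typing import Optional

import torch


def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(1.702 * x)


def gegelu(x: torch.Tensor, limit: Optional[float] = None) -> torch.Tensor:
    # Even-indexed channels feed the GELU gate, odd-indexed the linear branch.
    a_gelu, a_linear = x[..., ::2], x[..., 1::2]
    if limit is not None:
        # Clamp finite activations; infinities pass through unchanged.
        a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, a_gelu.clamp(max=limit))
        a_linear = torch.where(
            torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
        )
    return quick_gelu(a_gelu) * (a_linear + 1)


y = gegelu(torch.randn(2, 3, 8), limit=10.0)
assert y.shape == (2, 3, 4)  # half the input channels
```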
sglang/srt/models/qwen.py CHANGED
@@ -1,20 +1,20 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
sglang/srt/models/qwen2.py CHANGED
@@ -1,21 +1,21 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

  # Adapted from llama2.py
  # Modify details for the adaptation of Qwen2 model.
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
+
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.utils import make_layers

  Qwen2Config = None

@@ -230,11 +231,13 @@ class Qwen2Model(nn.Module):
              config.vocab_size,
              config.hidden_size,
          )
-         self.layers = nn.ModuleList(
-             [
-                 Qwen2DecoderLayer(config, i, quant_config=quant_config)
-                 for i in range(config.num_hidden_layers)
-             ]
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Qwen2DecoderLayer(
+                 layer_id=idx,
+                 config=config,
+                 quant_config=quant_config,
+             ),
          )
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

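
The hunk above replaces Qwen2's hand-rolled nn.ModuleList with the make_layers helper newly imported from sglang.srt.utils. As a rough orientation, here is a minimal sketch of the factory pattern involved; this is a hypothetical simplification, not sglang's actual implementation (vllm's variant of the same helper, used by phi3_small.py above, additionally returns start/end layer indices for pipeline-parallel partitioning):

```python
import torch
from torch import nn


def make_layers(num_layers: int, layer_fn) -> nn.ModuleList:
    # Build every decoder layer through a factory that receives the layer
    # index and a dotted prefix (useful for quantization and weight loading).
    return nn.ModuleList(
        layer_fn(idx=i, prefix=f"model.layers.{i}") for i in range(num_layers)
    )


layers = make_layers(4, lambda idx, prefix: nn.Linear(8, 8))
assert len(layers) == 4
```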
sglang/srt/models/qwen2_moe.py CHANGED
@@ -1,23 +1,22 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
- # coding=utf-8
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
  """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
- from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  import torch.nn.functional as F
@@ -27,11 +26,11 @@ from vllm.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.fused_moe import FusedMoE
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
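
This hunk is part of the wider MoE reshuffle in this release (see files 26-32 in the list above): the Triton fused-MoE kernels are now vendored under sglang, and Qwen2-MoE stops importing FusedMoE from vllm. With sglang 0.3.6.post1 installed, the import path shown in the diff is:

```python
# New location of the fused MoE layer (path taken from the diff above);
# the vllm import it replaces is kept for comparison.
from sglang.srt.layers.fused_moe_triton import FusedMoE
# previously: from vllm.model_executor.layers.fused_moe import FusedMoE
```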
sglang/srt/models/qwen2_vl.py CHANGED
@@ -44,6 +44,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
  )
  from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
  from sglang.srt.managers.schedule_batch import ImageInputs
@@ -559,6 +560,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
          )

          self.logits_processor = LogitsProcessor(config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

      def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
          pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -577,6 +579,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          forward_batch: ForwardBatch,
+         get_embedding: bool = False,
      ):
          """Run forward pass for Qwen2-VL.

@@ -599,8 +602,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
              image_inputs = [
                  img for img in forward_batch.image_inputs if img is not None
              ]
-
-             positions = forward_batch.mrope_positions
+             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+                 positions = forward_batch.mrope_positions
              if (
                  forward_batch.forward_mode.is_decode()
                  or image_inputs is None
@@ -616,7 +619,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):

              inputs_embeds = self.model.embed_tokens(input_ids)
              extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-             prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
+             prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
              for i, image in enumerate(forward_batch.image_inputs):
                  if image is None:
                      continue
@@ -655,9 +658,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
              forward_batch=forward_batch,
              input_embeds=inputs_embeds,
          )
-         return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
-         )
+
+         if not get_embedding:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head.weight, forward_batch
+             )
+         else:
+             return self.pooler(hidden_states, forward_batch)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
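
The get_embedding branch added above mirrors the one in the new phi3_small.py: the same forward either returns logits or hands the hidden states to Pooler(pooling_type=PoolingType.LAST, normalize=True). For intuition, here is a plain-PyTorch sketch of that pooling step under assumed behavior (take each sequence's final hidden state from the packed batch and L2-normalize it); sglang's actual Pooler class may differ in detail:

```python
import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: (total_tokens, hidden) packed across the batch;
    # seq_lens: (num_seqs,) token count per sequence.
    last_idx = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's last token
    return F.normalize(hidden_states[last_idx], p=2, dim=-1)


h = torch.randn(10, 16)  # two sequences packed into 10 tokens
emb = last_token_pool(h, torch.tensor([4, 6]))
assert emb.shape == (2, 16)
```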