sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/configs/model_config.py +24 -1
  2. sglang/srt/conversation.py +21 -2
  3. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  4. sglang/srt/disaggregation/ascend/conn.py +44 -0
  5. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  6. sglang/srt/disaggregation/mooncake/conn.py +15 -14
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  8. sglang/srt/disaggregation/utils.py +25 -3
  9. sglang/srt/entrypoints/engine.py +1 -1
  10. sglang/srt/entrypoints/http_server.py +1 -0
  11. sglang/srt/entrypoints/openai/protocol.py +11 -0
  12. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/kimik2_detector.py +220 -0
  15. sglang/srt/hf_transformers_utils.py +18 -0
  16. sglang/srt/jinja_template_utils.py +8 -0
  17. sglang/srt/layers/communicator.py +17 -4
  18. sglang/srt/layers/linear.py +12 -2
  19. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  20. sglang/srt/layers/moe/ep_moe/layer.py +2 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
  22. sglang/srt/layers/moe/topk.py +8 -2
  23. sglang/srt/layers/parameter.py +19 -3
  24. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  25. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  26. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  27. sglang/srt/managers/io_struct.py +27 -2
  28. sglang/srt/managers/mm_utils.py +55 -94
  29. sglang/srt/managers/schedule_batch.py +16 -5
  30. sglang/srt/managers/scheduler.py +21 -1
  31. sglang/srt/managers/tokenizer_manager.py +16 -0
  32. sglang/srt/mem_cache/memory_pool.py +65 -40
  33. sglang/srt/model_executor/forward_batch_info.py +13 -1
  34. sglang/srt/model_loader/loader.py +23 -12
  35. sglang/srt/models/deepseek_janus_pro.py +1 -1
  36. sglang/srt/models/deepseek_v2.py +62 -17
  37. sglang/srt/models/deepseek_vl2.py +1 -1
  38. sglang/srt/models/gemma3_mm.py +1 -1
  39. sglang/srt/models/gemma3n_mm.py +6 -3
  40. sglang/srt/models/internvl.py +8 -2
  41. sglang/srt/models/kimi_vl.py +8 -2
  42. sglang/srt/models/llama.py +2 -0
  43. sglang/srt/models/llava.py +3 -1
  44. sglang/srt/models/llavavid.py +1 -1
  45. sglang/srt/models/minicpmo.py +1 -2
  46. sglang/srt/models/minicpmv.py +1 -1
  47. sglang/srt/models/mixtral_quant.py +4 -0
  48. sglang/srt/models/mllama4.py +13 -4
  49. sglang/srt/models/phi4mm.py +8 -2
  50. sglang/srt/models/phimoe.py +553 -0
  51. sglang/srt/models/qwen2.py +2 -0
  52. sglang/srt/models/qwen2_5_vl.py +10 -7
  53. sglang/srt/models/qwen2_vl.py +12 -1
  54. sglang/srt/models/vila.py +8 -2
  55. sglang/srt/multimodal/processors/base_processor.py +197 -137
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  57. sglang/srt/multimodal/processors/gemma3.py +4 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  59. sglang/srt/multimodal/processors/internvl.py +1 -1
  60. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  61. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  62. sglang/srt/multimodal/processors/minicpm.py +4 -3
  63. sglang/srt/multimodal/processors/mllama4.py +1 -1
  64. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  65. sglang/srt/multimodal/processors/pixtral.py +1 -1
  66. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  67. sglang/srt/multimodal/processors/vila.py +1 -1
  68. sglang/srt/server_args.py +11 -4
  69. sglang/srt/utils.py +154 -31
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
  72. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
  73. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/phimoe.py ADDED
@@ -0,0 +1,553 @@
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class PhiMoEConfig(PretrainedConfig):
+
+    model_type = "phimoe"
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        head_dim=None,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=1e6,
+        sliding_window=None,
+        attention_dropout=0.0,
+        num_experts_per_tok=2,
+        num_local_experts=16,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        router_jitter_noise=0.0,
+        attention_bias=False,
+        lm_head_bias=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        self.attention_bias = attention_bias
+        self.lm_head_bias = lm_head_bias
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        if head_dim is None:
+            head_dim = hidden_size // num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_local_experts = num_local_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+def sparsemixer(scores, jitter_eps=0.01):
+    ################ Select first expert (topk=2) ################
+
+    # compute mask for sparsity
+    mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
+    factor = scores.abs().clamp(min=mask_logits_threshold)
+    mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (
+        2 * jitter_eps
+    )
+
+    # apply mask
+    masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
+    selected_experts = max_ind
+
+    # compute scores for gradients
+    masked_gates = torch.softmax(masked_gates, dim=-1)
+    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
+
+    multiplier = multiplier_o
+
+    # masked out first expert
+    masked_scores = torch.scatter(
+        scores,
+        -1,
+        selected_experts,
+        float("-inf"),
+    )
+
+    ################ Select second expert (topk=2) ################
+    # compute mask for sparsity
+    mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
+    factor = scores.abs().clamp(min=mask_logits_threshold)
+    mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (
+        2 * jitter_eps
+    )
+
+    # apply mask
+    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
+    selected_experts_top2 = max_ind
+    # compute scores for gradients
+    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
+    multiplier_top2 = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
+
+    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
+    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
+
+    return (
+        multiplier,
+        selected_experts,
+    )
+
+
+def phimoe_routing_function(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert topk == 2, "Only top-2 routing is supported"
+    assert renormalize is False, "Renormalization is not supported"
+
+    topk_weights, topk_ids = sparsemixer(gating_output)
+    return topk_weights, topk_ids
+
+
+class PhiMoE(nn.Module):
+    """A tensor-parallel MoE implementation for PhiMoE that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            hidden_size,
+            num_experts,
+            bias=False,
+            quant_config=None,
+        )
+
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=True,
+            renormalize=False,
+            quant_config=quant_config,
+            custom_routing_function=phimoe_routing_function,
+            prefix=add_prefix("experts", prefix),
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
+
+
+class PhiMoEAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: Optional[int] = None,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        layer_id: int = 0,
+        attention_bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[dict] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
+        if head_dim is None:
+            head_dim = hidden_size // num_heads
+        self.head_dim = head_dim
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=attention_bias,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            prefix=add_prefix("o_proj", prefix),
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            rope_scaling=self.rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class PhiMoEDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiMoEConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = PhiMoEAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=getattr(
+                config, "head_dim", self.hidden_size // config.num_attention_heads
+            ),
+            rope_theta=rope_theta,
+            layer_id=layer_id,
+            attention_bias=config.attention_bias,
+            quant_config=quant_config,
+            rope_scaling=config.rope_scaling,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.block_sparse_moe = PhiMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("block_sparse_moe", prefix),
+        )
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
+        )
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = hidden_states + residual
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.block_sparse_moe(
+            hidden_states, forward_batch=forward_batch
+        )
+
+        hidden_states = hidden_states + residual
+        return hidden_states, residual
+
+
+class PhiMoEModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiMoEConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: PhiMoEDecoderLayer(
+                config, int(prefix.split(".")[-1]), quant_config, prefix=prefix
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+        self.norm = nn.LayerNorm(
+            config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor]:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions, hidden_states, residual, forward_batch=forward_batch
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class PhiMoEForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiMoEConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = PhiMoEModel(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
+            bias=True,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = PhiMoEForCausalLM
sglang/srt/models/qwen2.py CHANGED
@@ -538,6 +538,8 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -56,7 +56,6 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -507,11 +506,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds
 
-    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
-        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
-        )
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [getattr(item, "pixel_values_videos") for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
         return video_embeds
 
     def get_input_embeddings(self):
@@ -553,7 +556,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
 
sglang/srt/models/qwen2_vl.py CHANGED
@@ -493,6 +493,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds
 
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [item.pixel_values_videos for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+        return video_embeds
+
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
         video_embeds = self.visual(
@@ -538,7 +549,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
            positions=positions,
         )
 
sglang/srt/models/vila.py CHANGED
@@ -17,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
@@ -223,7 +227,9 @@ class VILAForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             get_embedding=get_embedding,
             positions=positions,
         )