sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py
CHANGED
@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)

-    def set_eagle3_layers_to_capture(self):
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if not self.pp_group.is_last_rank:
             return

-        self.capture_aux_hidden_states = True
-        num_layers = self.config.num_hidden_layers
-        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]


 class Phi3ForCausalLM(LlamaForCausalLM):
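
A minimal usage sketch of the new optional argument (not part of the package; it assumes a hypothetical LlamaForCausalLM instance `model` on the last pipeline-parallel rank with config.num_hidden_layers == 32):

    # Default capture points: layers 2, num_layers // 2 and num_layers - 3.
    model.set_eagle3_layers_to_capture()
    assert model.model.layers_to_capture == [2, 16, 29]

    # Explicit capture points: sglang stores layer_id + 1 because the aux hidden
    # state for layer i is taken from the output of layer i - 1.
    model.set_eagle3_layers_to_capture(layer_ids=[1, 15, 30])
    assert model.model.layers_to_capture == [2, 16, 31]
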
sglang/srt/models/llama4.py
CHANGED
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -394,57 +412,26 @@
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            # Self Attention
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )

-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )

         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
-
-
-
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )

         return hidden_states, residual

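The rewrite above delegates the data-parallel gather/scatter and all-reduce bookkeeping to the shared LayerCommunicator helper; the layer itself only decides whether it (and its predecessor) is a sparse MoE layer. A small illustration of the `_is_moe_layer` rule, using an illustrative interleave_moe_layer_step of 2 (the real value comes from the model config):

    interleave_moe_layer_step = 2  # illustrative value; read from config in the real code
    is_moe = [(layer_id + 1) % interleave_moe_layer_step == 0 for layer_id in range(6)]
    # -> [False, True, False, True, False, True]: every second layer is a MoE layer
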
sglang/srt/models/llama_eagle3.py
CHANGED
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP


 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             prefix=add_prefix("qkv_proj", prefix),
         )

+        if config.model_type == "llama4_text":
+            inter_size = config.intermediate_size_mlp
+        else:
+            inter_size = config.intermediate_size
+
+        self.mlp = LlamaMLP(
+            config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
+        )
+
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
+
         if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.target_hidden_size
         else:
-            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.hidden_size_in * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+        )
+
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):

         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = True
+        self.hot_token_id = None
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+        # Define the parameter mapping for stacked parameters
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "d2t" in name:
                 # d2t stores diffs between draft id and target id
                 self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
-
-
-
-
-
-
+                continue
+
+            if "t2d" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param_name = f"model.{name}" if name not in params_dict else name
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Handle regular parameters
+                param_name = name if name in params_dict else f"model.{name}"
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)

     def get_hot_token_id(self):
         return self.hot_token_id

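For reference, the stacked-parameter loop added above relies on Python's for/else: the else branch only runs when no mapping matched. A self-contained sketch of the name remapping, using a hypothetical checkpoint tensor name:

    stacked_params_mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    name = "midlayer.self_attn.k_proj.weight"  # hypothetical checkpoint name
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            print(name.replace(weight_name, param_name), shard_id)
            # -> midlayer.self_attn.qkv_proj.weight k
            break
    else:
        print("not a stacked parameter; loaded as-is")
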
sglang/srt/models/llava.py
CHANGED
@@ -41,16 +41,16 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
 )
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.multimodal.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
 from sglang.srt.utils import add_prefix, flatten_nested_list, logger

sglang/srt/models/minicpmo.py
CHANGED
@@ -32,7 +32,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import DynamicCache, EncoderDecoderCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from transformers.models.whisper.modeling_whisper import (
-    WHISPER_ATTENTION_CLASSES,
+    WhisperAttention,
     WhisperConfig,
     WhisperEncoder,
 )
@@ -1090,7 +1090,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,

sglang/srt/models/mistral.py
CHANGED
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Inference-only Mistral model."""

-from typing import List
+from typing import List

 import torch
 from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector

sglang/srt/models/mllama4.py
CHANGED
@@ -16,7 +16,9 @@ from sglang.srt.managers.mm_utils import (
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cpu
+
+_is_cpu = is_cpu()


 class Llama4ForConditionalGeneration(nn.Module):
@@ -50,10 +52,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        im_token_id: int = mm_inputs.im_token_id
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
@@ -110,13 +109,17 @@ class Llama4ForConditionalGeneration(nn.Module):

         # rotary embeds should be sliced
         if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_key_value_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_total_num_kv_heads
+            else:
+                dim = self.language_model.config.num_key_value_heads
+            loaded_weight = permute(loaded_weight, dim)
         elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_attention_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_num_attention_heads
+            else:
+                dim = self.language_model.config.num_attention_heads
+            loaded_weight = permute(loaded_weight, dim)

         return name, loaded_weight

@@ -223,5 +226,34 @@ class Llama4ForConditionalGeneration(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
+            self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        # For EAGLE3, we delegate to the language model which should have this method
+        # If the language model doesn't have lm_head (like EAGLE3), we return None for head
+        embed = self.language_model.get_embed()
+        if hasattr(self.language_model, "get_embed_and_head"):
+            return self.language_model.get_embed_and_head()
+        elif hasattr(self.language_model, "lm_head"):
+            return embed, self.language_model.lm_head.weight
+        else:
+            # For EAGLE3, head might not be needed
+            return embed, None
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.language_model, "set_embed_and_head"):
+            return self.language_model.set_embed_and_head(embed, head)
+        else:
+            # For EAGLE3, only set embed
+            return self.language_model.set_embed(embed)
+
+    def get_embed(self):
+        return self.language_model.get_embed()
+
+    def set_embed(self, embed):
+        return self.language_model.set_embed(embed)
+

 EntryClass = Llama4ForConditionalGeneration

sglang/srt/models/phi4mm.py
CHANGED
@@ -446,9 +446,7 @@ class Phi4MMForCausalLM(nn.Module):
         return hidden_states

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def should_apply_lora(self, module_name: str) -> bool:

sglang/srt/models/pixtral.py
CHANGED
@@ -268,15 +268,14 @@ class PixtralHFVisionModel(nn.Module):

     DEFAULT_IMAGE_TOKEN_ID = 10

-    def pad_input_ids(self, input_ids: List[int],
-        return self.input_padder.pad_input_tokens(input_ids,
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, mm_inputs)

     def __init__(
         self,
         config: PixtralVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         *,
-        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
     ) -> None:
@@ -314,11 +313,8 @@ class PixtralHFVisionModel(nn.Module):
         )

         # Initialize patch position embedding
-        self.image_token_id = image_token_id
         self.patch_positional_embedding = PixtralRotaryEmbedding(config)
-        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
-            [self.image_token_id]
-        )
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens()

     @property
     def dtype(self):

sglang/srt/models/qwen2.py
CHANGED
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -100,6 +101,7 @@ class Qwen2Attention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        head_dim: Optional[int] = None,
         layer_id: int = 0,
         rope_theta: float = 1000000,
         rope_scaling: Optional[Dict[str, Any]] = None,
@@ -123,7 +125,10 @@
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        if head_dim is not None:
+            self.head_dim = head_dim
+        else:
+            self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -185,16 +190,19 @@
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            head_dim=head_dim,
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
@@ -246,6 +254,7 @@ class Qwen2Model(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -258,6 +267,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
@@ -272,6 +282,7 @@ class Qwen2Model(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
@@ -282,6 +293,9 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)

+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
             return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -310,7 +324,12 @@ class Qwen2Model(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]

+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -326,8 +345,16 @@ class Qwen2Model(nn.Module):
                 }
             )
         else:
-            hidden_states, _ = self.norm(hidden_states, residual)
-            return hidden_states
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+            if len(aux_hidden_states) == 0:
+                return hidden_states
+
+            return hidden_states, aux_hidden_states

         # If this function is called, it should always initialize KV cache scale
         # factors (or else raise an exception). Thus, handled exceptions should
@@ -398,6 +425,7 @@ class Qwen2ForCausalLM(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
+
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()

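With layers_to_capture non-empty, Qwen2Model.forward now returns a (hidden_states, aux_hidden_states) tuple instead of a bare tensor on the final pipeline stage. A hedged sketch of how calling code can tolerate both shapes (`output` is a placeholder for whatever the model returned, not the scheduler's actual variable):

    if isinstance(output, tuple):
        hidden_states, aux_hidden_states = output  # EAGLE3 capture enabled
    else:
        hidden_states, aux_hidden_states = output, []
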
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -493,9 +493,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: