sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/hunyuan.py
CHANGED
@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,7 +49,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
sglang/srt/models/kimi_vl.py
CHANGED
@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
         return res

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def forward(
sglang/srt/models/llama.py
CHANGED
@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)

-    def set_eagle3_layers_to_capture(self):
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if not self.pp_group.is_last_rank:
             return

-
-
-
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]


 class Phi3ForCausalLM(LlamaForCausalLM):
sglang/srt/models/llama4.py
CHANGED
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -387,6 +389,22 @@
             config.hidden_size, eps=config.rms_norm_eps
         )

+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -394,57 +412,26 @@
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-            residual
-
-
-
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )

-
-
-
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )

         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
-
-
-
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )

         return hidden_states, residual

sglang/srt/models/llama_eagle3.py
CHANGED
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP


 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             prefix=add_prefix("qkv_proj", prefix),
         )

+        if config.model_type == "llama4_text":
+            inter_size = config.intermediate_size_mlp
+        else:
+            inter_size = config.intermediate_size
+
+        self.mlp = LlamaMLP(
+            config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
+        )
+
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-
+
         if hasattr(config, "target_hidden_size"):
-            self.
+            self.hidden_size_in = config.target_hidden_size
         else:
-            self.
+            self.hidden_size_in = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.hidden_size_in * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+        )
+
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):

         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = True
+        self.hot_token_id = None
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+        # Define the parameter mapping for stacked parameters
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "d2t" in name:
                 # d2t stores diffs between draft id and target id
                 self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
-
-
-
-
-
-
+                continue
+
+            if "t2d" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param_name = f"model.{name}" if name not in params_dict else name
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Handle regular parameters
+                param_name = name if name in params_dict else f"model.{name}"
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)

     def get_hot_token_id(self):
         return self.hot_token_id
sglang/srt/models/llava.py
CHANGED
@@ -41,16 +41,16 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
 )
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.multimodal.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
 from sglang.srt.utils import add_prefix, flatten_nested_list, logger

sglang/srt/models/minicpmo.py
CHANGED
@@ -32,7 +32,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import DynamicCache, EncoderDecoderCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from transformers.models.whisper.modeling_whisper import (
-
+    WhisperAttention,
     WhisperConfig,
     WhisperEncoder,
 )
@@ -1090,7 +1090,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn =
+        self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
sglang/srt/models/mistral.py
CHANGED
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Inference-only Mistral model."""

-from typing import List
+from typing import List

 import torch
 from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
|