sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py
CHANGED
@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
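
This change adds support for tied word embeddings in EXAONE: when `tie_word_embeddings` is set, the output head reuses the input embedding table instead of allocating a separate `ParallelLMHead`. A minimal sketch of the same idea in plain PyTorch (generic class, not the sglang implementation):

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # Reuse the embedding weight as the output projection (no extra parameters).
            self.lm_head = self.wte
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def logits(self, hidden: torch.Tensor) -> torch.Tensor:
        # Both branches expose a [vocab_size, hidden_size] weight,
        # so the projection is the same matmul either way.
        return hidden @ self.lm_head.weight.T
```
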
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

 import torch
 from torch import nn
-from transformers import AutoModel, Gemma3Config, PreTrainedModel
+from transformers import Gemma3Config, PreTrainedModel

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".out_proj.",
     ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         "v_proj": ("qkv_proj", 2),
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
     }

     packed_modules_mapping = {
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         super().__init__(config=config)
         self.config = config
         self.quant_config = quant_config
-
-
-
-
-
-
-
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+
+        self.vision_tower = SiglipVisionModel(
+            config=config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size

         # Text model
         self.language_model = Gemma3ForCausalLM(
-            config.text_config,
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         if self.language_model.logits_processor.logit_scale:
             logit_scale = getattr(config, "logit_scale", 1.0)
@@ -278,13 +282,28 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-
-
-
-
-
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])

-
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_value in all_pixel_values:
+            # Add batch dimension for single image processing
+            pixel_value_batch = pixel_value.unsqueeze(0)
+            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
+            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+
+            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
+            vision_outputs_list.append(vision_output)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features

@@ -360,6 +379,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return self.language_model.tie_weights()

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
         """Load weights for the model."""
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -373,21 +400,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 loaded_params.update(causal_loaded_params)
                 continue
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                    # Skip loading extra bias for GPTQ models
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
@@ -398,5 +437,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):


 EntryClass = Gemma3ForConditionalGeneration
-
-AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
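
Two things stand out here: the HF `AutoModel` vision tower is replaced by sglang's own `SiglipVisionModel`, and `load_weights` gains the usual fused-parameter loading loop, where checkpoint shards named `q_proj`/`k_proj`/`v_proj` are written into a single fused `qkv_proj` parameter and the `for`/`else` falls through to a plain copy for everything else. A simplified, self-contained sketch of that control flow (hypothetical parameter names, equal q/k/v shard sizes assumed; not the real Gemma3 state dict):

```python
from typing import Dict, Iterable, Tuple

import torch

STACKED = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def load(params: Dict[str, torch.Tensor], weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    for name, w in weights:
        for fused, shard, shard_id in STACKED:
            if shard not in name:
                continue
            # e.g. "layers.0.self_attn.q_proj.weight" -> "layers.0.self_attn.qkv_proj.weight"
            target = params[name.replace(shard, fused)]
            rows = w.shape[0]
            offset = {"q": 0, "k": rows, "v": 2 * rows}[shard_id]  # assumes equal shard sizes
            target[offset : offset + rows].copy_(w)
            break
        else:
            # Not a stacked shard: copy the tensor directly.
            params[name].copy_(w)
```
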
sglang/srt/models/llama4.py
CHANGED
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
+from sglang.srt.utils import (
+    add_prefix,
+    fast_topk,
+    get_compiler_backend,
+    is_cuda,
+    make_layers,
+)
+
+_is_cuda = is_cuda()

 logger = logging.getLogger(__name__)

@@ -131,7 +139,7 @@ class Llama4MoE(nn.Module):
         return out_aD

     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4:
+        if hidden_states.shape[0] < 4 and _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
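
The functional change is small: the shared/routed-overlap fast path for tiny batches is now taken only on CUDA, and other backends always use the normal MoE path. The guard pattern, sketched generically (using `torch.cuda.is_available()` as a stand-in for sglang's `is_cuda()` helper):

```python
import torch

_IS_CUDA = torch.cuda.is_available()  # stand-in for sglang.srt.utils.is_cuda()


def forward_core(hidden_states, fused_path, normal_path):
    # The fused overlap kernel is only exercised for very small batches,
    # and only when running on CUDA; everything else falls back.
    if hidden_states.shape[0] < 4 and _IS_CUDA:
        return fused_path(hidden_states)
    return normal_path(hidden_states)
```
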
sglang/srt/models/llava.py
CHANGED
@@ -135,7 +135,6 @@ class LlavaBaseForCausalLM(nn.Module):
         """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-
         selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
         if self.vision_feature_select_strategy in ["default", "patch"]:
             selected_image_feature = selected_image_feature[:, 1:]
@@ -146,7 +145,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
             )
         image_features = self.multi_modal_projector(selected_image_feature)
-
         return image_features

     @torch.no_grad()
@@ -613,6 +611,10 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):

     MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector

+    @property
+    def dtype(self):
+        return self.torch_dtype
+
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         if hasattr(self.vision_tower, "pad_input_ids"):
             return self.vision_tower.pad_input_ids(input_ids, image_inputs)
@@ -672,11 +674,17 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         assert hasattr(config, "text_config")
         assert hasattr(config, "vision_config")
         self.config = config
-        self.text_config = config.text_config
-        self.vision_config = config.vision_config
+        self.text_config = self.config.text_config
+        self.vision_config = self.config.vision_config
+        self.torch_dtype = getattr(self.config, "torch_dtype")
+
+        if not getattr(self.text_config, "torch_dtype"):
+            self.text_config.torch_dtype = self.torch_dtype
+        if not getattr(self.vision_config, "torch_dtype"):
+            self.vision_config.torch_dtype = self.torch_dtype

         if not hasattr(self.config, "vocab_size"):
-            self.config.vocab_size = self.config.text_config.vocab_size
+            self.config.vocab_size = self.text_config.vocab_size
         if not hasattr(self.config, "image_aspect_ratio"):
             self.config.image_aspect_ratio = "anyres"
         if not hasattr(self.config, "image_grid_pinpoints"):
@@ -697,39 +705,39 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         if not hasattr(self.config, "projector_hidden_act"):
             self.config.projector_hidden_act = "gelu"

-        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
         self.vision_feature_select_strategy = getattr(
-            config, "vision_feature_select_strategy", "full"
+            self.config, "vision_feature_select_strategy", "full"
         )
-        self.image_size = self.config.vision_config.image_size
-        self.patch_size = self.config.vision_config.patch_size
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size

-        self.mm_patch_merge_type = config.mm_patch_merge_type
-        self.image_aspect_ratio = config.image_aspect_ratio
-        self.image_grid_pinpoints = config.image_grid_pinpoints
+        self.mm_patch_merge_type = self.config.mm_patch_merge_type
+        self.image_aspect_ratio = self.config.image_aspect_ratio
+        self.image_grid_pinpoints = self.config.image_grid_pinpoints

         self.image_feature_len = int((self.image_size // self.patch_size) ** 2)

         self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)

         language_model_cls = self._get_sgl_model_cls(
-            config.text_config, AutoModelForCausalLM
+            self.text_config, AutoModelForCausalLM
         )
-        vision_model_cls = self._get_sgl_model_cls(
+        vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
         self.language_model = language_model_cls(
-            config.text_config,
+            self.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
         self.vision_tower = vision_model_cls(
-            config.vision_config,
+            self.vision_config,
             quant_config=quant_config,
             prefix=add_prefix("vision_tower", prefix),
         )

-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+        if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(
+                torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
             )

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
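
Most of this diff threads a single `torch_dtype` from the top-level config into the text and vision sub-configs, so both sub-models are constructed with a consistent dtype, and the new `dtype` property exposes it to callers that need to cast pixel values. A rough sketch of the propagation step, assuming a config object with nested `text_config`/`vision_config` attributes:

```python
def propagate_dtype(config) -> None:
    """Copy the top-level torch_dtype into sub-configs that do not set their own."""
    torch_dtype = getattr(config, "torch_dtype", None)
    for sub_config in (config.text_config, config.vision_config):
        if not getattr(sub_config, "torch_dtype", None):
            sub_config.torch_dtype = torch_dtype
```
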
sglang/srt/models/mimo_mtp.py
ADDED
@@ -0,0 +1,220 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.mimo import MiMoForCausalLM
+from sglang.srt.models.qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2MLP,
+    Qwen2Model,
+)
+from sglang.srt.utils import add_prefix
+
+
+class MiMoMultiTokenPredictorLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        prefix: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_proj = nn.Linear(
+            config.hidden_size * 2, config.hidden_size, bias=False
+        )
+        self.mtp_block = Qwen2DecoderLayer(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        # masking inputs at position 0, as not needed by MTP
+        hidden_states[positions == 0] = 0
+
+        hidden_states = self.input_proj(
+            torch.cat(
+                (
+                    self.hidden_layernorm(forward_batch.spec_info.hidden_states),
+                    self.token_layernorm(hidden_states),
+                ),
+                dim=-1,
+            )
+        )
+
+        hidden_states, residual = self.mtp_block(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            residual=None,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class MiMoMTP(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+
+        self.model = MiMoMultiTokenPredictorLayer(
+            config,
+            prefix,
+            quant_config,
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+            name = self.map_model_name_to_mtp_param_name(name)
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mtp_block" not in name:
+                    break
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if "mtp_block" not in name and (
+                    "embed_tokens" not in name
+                    and "lm_head" not in name
+                    and "token_layernorm" not in name
+                    and "hidden_layernorm" not in name
+                    and "input_proj" not in name
+                    and "final_layernorm" not in name
+                ):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def map_model_name_to_mtp_param_name(self, name: str) -> str:
+        import re
+
+        name_without_prefix = [
+            "token_layernorm",
+            "hidden_layernorm",
+            "input_proj",
+            "final_layernorm",
+        ]
+        pattern = r"model.mtp_layers.(\d+)."
+        group = re.match(pattern, name)
+        if group is not None:
+            for sub_name in name_without_prefix:
+                if sub_name in name:
+                    name = name.replace(group.group(), "model.")
+                    return name
+            name = name.replace(group.group(), "model.mtp_block.")
+        return name
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+
+EntryClass = MiMoMTP
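
The trickiest part of the new module is checkpoint-name remapping: MiMo checkpoints prefix MTP tensors with `model.mtp_layers.<i>.`, which `map_model_name_to_mtp_param_name` rewrites so that layer-local tensors (the layernorms and `input_proj`) land directly under `model.` while decoder-layer weights go under `model.mtp_block.`. A standalone illustration of that rewrite on hypothetical checkpoint keys:

```python
import re

NAME_WITHOUT_PREFIX = ["token_layernorm", "hidden_layernorm", "input_proj", "final_layernorm"]


def map_name(name: str) -> str:
    match = re.match(r"model.mtp_layers.(\d+).", name)
    if match is not None:
        if any(sub in name for sub in NAME_WITHOUT_PREFIX):
            return name.replace(match.group(), "model.")
        return name.replace(match.group(), "model.mtp_block.")
    return name


# Hypothetical checkpoint keys, for illustration only:
print(map_name("model.mtp_layers.0.token_layernorm.weight"))   # -> model.token_layernorm.weight
print(map_name("model.mtp_layers.0.self_attn.q_proj.weight"))  # -> model.mtp_block.self_attn.q_proj.weight
print(map_name("lm_head.weight"))                               # -> lm_head.weight (unchanged)
```
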
sglang/srt/models/minicpmo.py
CHANGED
@@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel):
         slice_start_id: int = mm_input.slice_start_id
         slice_end_id: int = mm_input.slice_end_id

-
+        data_token_pairs = [
             (im_start_id, im_end_id),
             (slice_start_id, slice_end_id),
             (mm_input.audio_start_id, mm_input.audio_end_id),
         ]
-
+        data_start_token_ids = [im_start_id, mm_input.audio_start_id]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(
+            data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids
+        )

         return pattern.pad_input_tokens(input_ids, mm_input)

@@ -1823,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel):
         **kwargs: Any,
     ) -> torch.Tensor:

-        mm_input = forward_batch.merge_mm_inputs()
-        placeholder_token_ids = (
-            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
-            if forward_batch.contains_mm_inputs()
-            else []
-        )
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_tokens={
-                Modality.IMAGE: placeholder_token_ids,
-                Modality.AUDIO: placeholder_token_ids,
-            },
             positions=positions,
         )
         return hidden_states
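
The padding pattern here works on (start_id, end_id) token pairs: tokens between a start marker and its matching end marker are replaced with per-item pad values so the embedding pass can later scatter image or audio features into those slots; the new `data_start_token_ids` argument tells the pattern which markers open a span. A toy version of the span-padding idea, with made-up token ids (not the sglang `MultiModalityDataPaddingPatternTokenPairs` implementation):

```python
from typing import List, Tuple


def pad_between_pairs(input_ids: List[int], pairs: List[Tuple[int, int]], pad_value: int) -> List[int]:
    """Replace tokens strictly inside each (start, end) pair with pad_value."""
    end_for_start = dict(pairs)
    out = list(input_ids)
    i = 0
    while i < len(out):
        if out[i] in end_for_start:
            end = end_for_start[out[i]]
            j = i + 1
            while j < len(out) and out[j] != end:
                out[j] = pad_value
                j += 1
            i = j
        i += 1
    return out


# Made-up ids: 101/102 delimit an image span.
print(pad_between_pairs([7, 101, 5, 5, 5, 102, 9], [(101, 102)], 0))  # [7, 101, 0, 0, 0, 102, 9]
```
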
sglang/srt/models/mistral.py
CHANGED
@@ -13,6 +13,12 @@
 # ==============================================================================
 """Inference-only Mistral model."""

+from typing import List, Union
+
+import torch
+from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
+
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -20,4 +26,68 @@ class MistralForCausalLM(LlamaForCausalLM):
     pass


-
+class Mistral3ForConditionalGeneration:
+    MULTIMODAL_PROJECTOR_TYPE = Mistral3MultiModalProjector
+
+    def __init__(self, **kwargs):
+        # lazy load inner class
+        # to bypass circular import
+        from sglang.srt.models.llava import LlavaForConditionalGeneration
+
+        # override config: mistral's projector adds patchmerger that doesn't require padding
+        kwargs["config"].vision_config.pad_image_border = False
+
+        self.inner = LlavaForConditionalGeneration(**kwargs)
+        self.inner.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(
+            kwargs["config"]
+        )
+        self.inner.get_image_feature = self.get_image_feature
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(
+                    selected_image_feature.squeeze(0), image_sizes
+                )
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def __getattr__(self, name):
+        return getattr(self.inner, name)
+
+    def __hasattr__(self, name):
+        return hasattr(self.inner, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.inner(*args, **kwargs)
+
+
+EntryClass = [MistralForCausalLM, Mistral3ForConditionalGeneration]
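
`Mistral3ForConditionalGeneration` is a thin wrapper rather than a subclass: it builds a `LlavaForConditionalGeneration` internally (imported lazily to avoid a circular import), swaps in the Mistral3 projector, and forwards everything else through `__getattr__` and `__call__`. The delegation pattern in isolation looks roughly like this (generic names, not the sglang classes):

```python
class Inner:
    def __init__(self, scale: float):
        self.scale = scale

    def __call__(self, x: float) -> float:
        return self.scale * x


class Wrapper:
    """Delegates attribute access and calls to a wrapped instance."""

    def __init__(self, **kwargs):
        # Build the wrapped object inside __init__, mirroring the lazy import
        # used in the diff above to break the circular dependency.
        self.inner = Inner(**kwargs)

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so `self.inner`
        # itself is resolved without recursion.
        return getattr(self.inner, name)

    def __call__(self, *args, **kwargs):
        return self.inner(*args, **kwargs)


w = Wrapper(scale=2.0)
assert w.scale == 2.0 and w(3.0) == 6.0
```
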