sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py
CHANGED
@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.
-
-
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
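For context: the hunk above makes the EXAONE LM head honor `tie_word_embeddings`, reusing the input embedding matrix as the output projection when the config requests it. A minimal, self-contained sketch of the same weight-tying pattern in plain PyTorch; `TinyCausalLM` is illustrative only and not the sglang implementation, which shards the vocabulary with `ParallelLMHead`/`VocabParallelEmbedding`:

# Minimal sketch of weight tying between the token embedding and the LM head.
import torch
import torch.nn as nn

class TinyCausalLM(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool = True):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # Reuse the embedding weight as the output projection (no extra parameters).
            self.lm_head = lambda h: h @ self.wte.weight.t()
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq, hidden) -> logits: (batch, seq, vocab)
        return self.lm_head(hidden_states)

model = TinyCausalLM(vocab_size=32000, hidden_size=64)
logits = model(torch.randn(1, 4, 64))
print(logits.shape)  # torch.Size([1, 4, 32000])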
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
 import torch
 from torch import nn
-from transformers import
+from transformers import Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".out_proj.",
     ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         "v_proj": ("qkv_proj", 2),
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
     }
 
     packed_modules_mapping = {
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         super().__init__(config=config)
         self.config = config
         self.quant_config = quant_config
-
-
-
-
-
-
-
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+
+        self.vision_tower = SiglipVisionModel(
+            config=config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
 
         # Text model
         self.language_model = Gemma3ForCausalLM(
-            config.text_config,
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         if self.language_model.logits_processor.logit_scale:
             logit_scale = getattr(config, "logit_scale", 1.0)
@@ -278,13 +282,28 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-
-
-
-
-
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
 
-
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_value in all_pixel_values:
+            # Add batch dimension for single image processing
+            pixel_value_batch = pixel_value.unsqueeze(0)
+            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
+            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+
+            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
+            vision_outputs_list.append(vision_output)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
 
@@ -360,6 +379,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return self.language_model.tie_weights()
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
         """Load weights for the model."""
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -373,21 +400,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 loaded_params.update(causal_loaded_params)
                 continue
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                    # Skip loading extra bias for GPTQ models
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
@@ -398,5 +437,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 
 
 EntryClass = Gemma3ForConditionalGeneration
-
-AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
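The reworked `get_image_feature` above short-circuits when every item carries `precomputed_features`, and otherwise runs the SigLIP tower on one image at a time before concatenating along the batch dimension. A small stand-alone sketch of that loop-and-concatenate pattern; the `fake_tower` stub is an assumption used for illustration, not the sglang vision tower:

# Sketch of the per-image loop-and-concatenate pattern used above.
import torch

def fake_tower(pixel_values: torch.Tensor) -> torch.Tensor:
    # (1, C, H, W) -> (1, num_patches, embed_dim); here just a fixed-size stub.
    return torch.zeros(pixel_values.shape[0], 16, 8)

images = [torch.randn(3, 32, 32) for _ in range(3)]  # three single images
outputs = []
for img in images:
    batched = img.unsqueeze(0)           # add the batch dimension the tower expects
    outputs.append(fake_tower(batched))  # (1, 16, 8) each
features = torch.cat(outputs, dim=0)     # (3, 16, 8): one row per image
print(features.shape)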
sglang/srt/models/llama.py
CHANGED
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
                 config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
sglang/srt/models/llama4.py
CHANGED
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    add_prefix,
+    fast_topk,
+    get_compiler_backend,
+    is_cuda,
+    make_layers,
+)
+
+_is_cuda = is_cuda()
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +139,7 @@ class Llama4MoE(nn.Module):
         return out_aD
 
     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4:
+        if hidden_states.shape[0] < 4 and _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
@@ -198,7 +206,6 @@ class Llama4Attention(nn.Module):
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
 
@@ -342,7 +349,7 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
@@ -405,7 +412,7 @@ class Llama4DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -428,9 +435,9 @@ class Llama4DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
 
-        # TODO(ch-wan):
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
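The `_forward_core` change above gates the small-batch shared/routed overlap path on CUDA. A minimal sketch of this kind of capability-gated dispatch, with stub functions standing in for the Llama4MoE kernels (the names below are illustrative assumptions):

# Sketch of capability-gated dispatch: take a fast path only when it is supported.
import torch

_is_cuda = torch.cuda.is_available()

def _forward_core_normal(x: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for the general implementation

def _forward_core_shared_routed_overlap(x: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for the CUDA-only small-batch fast path

def forward_core(x: torch.Tensor) -> torch.Tensor:
    if x.shape[0] < 4 and _is_cuda:
        return _forward_core_shared_routed_overlap(x)
    return _forward_core_normal(x)

print(forward_core(torch.randn(2, 8)).shape)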
sglang/srt/models/llava.py
CHANGED
@@ -15,7 +15,8 @@
 
 import math
 import re
-from
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -28,10 +29,18 @@ from transformers import (
     Qwen2Config,
     SiglipVisionModel,
 )
+from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
+# leave till last and symbol only in case circular import
+import sglang.srt.models as sgl_models
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import general_mm_embed_routine
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix, flatten_nested_list
+from sglang.srt.utils import add_prefix, flatten_nested_list, logger
 
 
 class LlavaBaseForCausalLM(nn.Module):
@@ -114,10 +123,18 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs.image_offsets = offset_list
         return input_ids
 
-    def encode_images(
+    def encode_images(
+        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        encode images by vision tower and multimodal projector
+        Args:
+            pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
+        Returns:
+            torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
+        """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-
         selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
         if self.vision_feature_select_strategy in ["default", "patch"]:
             selected_image_feature = selected_image_feature[:, 1:]
@@ -128,7 +145,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
             )
         image_features = self.multi_modal_projector(selected_image_feature)
-
         return image_features
 
     @torch.no_grad()
@@ -583,4 +599,239 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         )
 
 
-
+class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
+    """
+    An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
+    It follows the structure of (vision_tower, multi_modal_projector, language_model)
+
+    Once a model config is loaded, text_config and vision_config will be extracted, and
+    LlavaForConditionalGeneration will load the language_model and vision_tower models
+    according to config.
+    """
+
+    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
+
+    @property
+    def dtype(self):
+        return self.torch_dtype
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        if hasattr(self.vision_tower, "pad_input_ids"):
+            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
+        else:
+            return super().pad_input_ids(input_ids, image_inputs)
+
+    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
+        """
+        Get the SGLang model implementation class according to config.
+
+        Args:
+            config: The config object of the model.
+            auto_model_type: The type of the auto model.
+
+        Returns:
+            The SGLang model implementation class.
+        """
+        config_cls_name = config.__class__.__name__
+        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
+        if arch := arch_name_mapping.get(config_cls_name):
+            if isinstance(arch, tuple):
+                arch = arch[0]
+                logger.warning(
+                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
+                )
+            try:
+                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
+            except Exception as e:
+                raise ValueError(
+                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
+                )
+        else:
+            raise ValueError(
+                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
+            )
+
+    @lru_cache
+    def _config_cls_name_to_arch_name_mapping(
+        self, auto_model_type: Type[AutoModel]
+    ) -> Dict[str, str]:
+        mapping = {}
+        for config_cls, archs in auto_model_type._model_mapping.items():
+            if isinstance(archs, tuple):
+                mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
+            else:
+                mapping[config_cls.__name__] = archs.__name__
+        return mapping
+
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert hasattr(config, "text_config")
+        assert hasattr(config, "vision_config")
+        self.config = config
+        self.text_config = self.config.text_config
+        self.vision_config = self.config.vision_config
+        self.torch_dtype = getattr(self.config, "torch_dtype")
+
+        if not getattr(self.text_config, "torch_dtype"):
+            self.text_config.torch_dtype = self.torch_dtype
+        if not getattr(self.vision_config, "torch_dtype"):
+            self.vision_config.torch_dtype = self.torch_dtype
+
+        if not hasattr(self.config, "vocab_size"):
+            self.config.vocab_size = self.text_config.vocab_size
+        if not hasattr(self.config, "image_aspect_ratio"):
+            self.config.image_aspect_ratio = "anyres"
+        if not hasattr(self.config, "image_grid_pinpoints"):
+            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
+            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
+            self.config.image_grid_pinpoints = [
+                [96, 96],
+                [224, 224],
+                [384, 384],
+                [512, 512],
+                [768, 768],
+                [1024, 1024],
+            ]
+        if not hasattr(self.config, "mm_patch_merge_type"):
+            self.config.mm_patch_merge_type = "flat"
+        if not hasattr(self.config, "image_token_index"):
+            self.config.image_token_index = 10
+        if not hasattr(self.config, "projector_hidden_act"):
+            self.config.projector_hidden_act = "gelu"
+
+        self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
+        self.vision_feature_select_strategy = getattr(
+            self.config, "vision_feature_select_strategy", "full"
+        )
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size
+
+        self.mm_patch_merge_type = self.config.mm_patch_merge_type
+        self.image_aspect_ratio = self.config.image_aspect_ratio
+        self.image_grid_pinpoints = self.config.image_grid_pinpoints
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+
+        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
+
+        language_model_cls = self._get_sgl_model_cls(
+            self.text_config, AutoModelForCausalLM
+        )
+        vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
+        self.language_model = language_model_cls(
+            self.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.vision_tower = vision_model_cls(
+            self.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
+        if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
+            )
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(selected_image_feature.squeeze(0))
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            placeholder_tokens=None,  # using mm_item.pad_value
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for LlavaForConditionalGeneration.
+
+        Unlike the base class implementation, this one doesn't need to handle
+        weight name remapping as the weights are already properly structured with
+        'language_model' and 'vision_tower' prefixes in the safetensors files.
+        """
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(
+                f"Unexpected select feature: {self.vision_feature_select_strategy}"
+            )
+
+        # Create dictionaries for direct parameter loading
+        params_dict = dict(self.named_parameters())
+
+        # Load weights directly without remapping
+        for name, loaded_weight in weights:
+            for part in ("language_model", "vision_tower"):
+                if name.startswith(part):
+                    name = name[len(part + ".") :]
+                    getattr(self, part).load_weights([(name, loaded_weight)])
+                    break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaForConditionalGeneration,
+]
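The new `load_weights` in `LlavaForConditionalGeneration` routes each checkpoint tensor to the submodule named by its prefix and strips that prefix before delegating. A small sketch of the same prefix-routing idea using plain dictionaries; the weight names below are made up for illustration:

# Sketch of routing checkpoint tensors by name prefix, as in the new load_weights.
import torch

weights = [
    ("language_model.embed_tokens.weight", torch.zeros(4, 4)),
    ("vision_tower.patch_embed.weight", torch.zeros(4, 4)),
    ("multi_modal_projector.linear_1.weight", torch.zeros(4, 4)),
]

routed = {"language_model": [], "vision_tower": [], "rest": []}
for name, tensor in weights:
    for part in ("language_model", "vision_tower"):
        if name.startswith(part):
            # strip "<part>." so the submodule sees its own local names
            routed[part].append((name[len(part) + 1 :], tensor))
            break
    else:
        routed["rest"].append((name, tensor))  # loaded directly on this module

print({k: [n for n, _ in v] for k, v in routed.items()})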