sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED

@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -400,11 +401,15 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )

         min_per_gpu_memory = get_available_gpu_memory(
-            self.device,
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

@@ -716,7 +721,10 @@ class ModelRunner:

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device,
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
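Both call sites now pass the GPU index plus a `distributed` flag and a CPU process group, so the free-memory figure can be agreed on across ranks instead of measured per process. Below is a minimal sketch of that pattern; the helper name and signature are assumptions for illustration, not sglang's actual `get_available_gpu_memory`:

```python
# Illustrative sketch only: reduce the per-rank free memory with MIN over
# a CPU (gloo) group, mirroring the distributed/cpu_group arguments above.
import torch
import torch.distributed as dist

def min_free_gpu_memory_gb(gpu_id: int, distributed: bool = False, cpu_group=None) -> float:
    free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_id)
    free_gb = free_bytes / (1 << 30)
    if distributed:
        # Reduce on a CPU group so the measurement itself does not
        # allocate GPU memory or synchronize CUDA streams.
        t = torch.tensor([free_gb], dtype=torch.float64)
        dist.all_reduce(t, op=dist.ReduceOp.MIN, group=cpu_group)
        free_gb = t.item()
    return free_gb
```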
|
@@ -1085,32 +1093,33 @@ class ModelRunner:
|
|
1085
1093
|
forward_batch: ForwardBatch,
|
1086
1094
|
skip_attn_backend_init: bool = False,
|
1087
1095
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
1088
|
-
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
1096
|
+
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
1089
1097
|
can_run_cuda_graph = bool(
|
1090
1098
|
forward_batch.forward_mode.is_cuda_graph()
|
1091
1099
|
and self.cuda_graph_runner
|
1092
1100
|
and self.cuda_graph_runner.can_run(forward_batch)
|
1093
1101
|
)
|
1094
1102
|
if can_run_cuda_graph:
|
1095
|
-
|
1103
|
+
ret = self.cuda_graph_runner.replay(
|
1096
1104
|
forward_batch,
|
1097
1105
|
skip_attn_backend_init=skip_attn_backend_init,
|
1098
1106
|
pp_proxy_tensors=pp_proxy_tensors,
|
1099
1107
|
)
|
1100
|
-
|
1101
|
-
|
1102
|
-
return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
1108
|
+
elif forward_batch.forward_mode.is_decode():
|
1109
|
+
ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
1103
1110
|
elif forward_batch.forward_mode.is_extend():
|
1104
|
-
|
1111
|
+
ret = self.forward_extend(
|
1105
1112
|
forward_batch,
|
1106
1113
|
skip_attn_backend_init=skip_attn_backend_init,
|
1107
1114
|
pp_proxy_tensors=pp_proxy_tensors,
|
1108
1115
|
)
|
1109
1116
|
elif forward_batch.forward_mode.is_idle():
|
1110
|
-
|
1117
|
+
ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
1111
1118
|
else:
|
1112
1119
|
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
1113
1120
|
|
1121
|
+
return ret, can_run_cuda_graph
|
1122
|
+
|
1114
1123
|
def _preprocess_logits(
|
1115
1124
|
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
1116
1125
|
):
|
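With the new return type, `ModelRunner.forward` hands back both the output and whether the step was served by a CUDA-graph replay. A hypothetical call site would unpack accordingly:

```python
# Hypothetical caller sketch: forward() now returns a 2-tuple, so call
# sites (e.g., in the TP worker) are expected to unpack both values.
logits_output, can_run_cuda_graph = model_runner.forward(forward_batch)
if can_run_cuda_graph:
    # e.g., count this step as a CUDA-graph replay in metrics
    pass
```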
@@ -1145,9 +1154,7 @@ class ModelRunner:
             [self.sample(values, forward_batch) for values in logits_output],
             axis=-1,
         )
-
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
+
         self._preprocess_logits(logits_output, forward_batch.sampling_info)

         # Sample the next tokens

@@ -1158,8 +1165,6 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids

     @property
sglang/srt/models/deepseek_janus_pro.py
CHANGED

@@ -188,7 +188,7 @@ def trunc_normal_tf_(
     best when :math:`a \\leq \text{mean} \\leq b`.
     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
-    and the result is
+    and the result is subsequently scaled and shifted by the mean and std args.
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution

@@ -735,7 +735,7 @@ class VisionTransformer(nn.Module):
             img_size: Input image size.
             patch_size: Patch size.
             in_chans: Number of image input channels.
-            num_classes:
+            num_classes: Number of classes for classification head.
             global_pool: Type of global pooling for final sequence (default: 'token').
             embed_dim: Transformer embedding dimension.
             depth: Depth of transformer.
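The completed docstring spells out the 'tf' truncation order: the bounds are applied while sampling from N(0, 1), and the result is scaled and shifted afterwards. A small worked example of that order using stock PyTorch initializers (this snippet is illustrative, not part of the diff):

```python
import torch
import torch.nn as nn

t = torch.empty(256, 256)
# Bounds [a, b] are applied while sampling with mean=0, std=1 ...
nn.init.trunc_normal_(t, mean=0.0, std=1.0, a=-2.0, b=2.0)
# ... and the result is subsequently scaled and shifted by the std/mean args.
t.mul_(0.02).add_(0.5)
```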
sglang/srt/models/deepseek_v2.py
CHANGED

@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(

@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (

@@ -1287,9 +1287,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.mlp(hidden_states)

-        # TODO(ch-wan):
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (

@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
         if hidden_states.shape[0] != 0:
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual

@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)

@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens

@@ -1475,9 +1475,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"

@@ -1486,22 +1487,24 @@ class DeepseekV2ForCausalLM(nn.Module):
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != architecture
+                not _is_cuda
+                or self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
                 global_server_args_dict["n_share_experts_fusion"] = 0
                 log_info_on_rank0(
                     logger,
-                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
+                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
                     self.n_share_experts_fusion == self.tp_size
-                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
         elif self.n_share_experts_fusion == 0:
             if (
-                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                _is_cuda
+                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                 and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])

@@ -1663,7 +1666,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
-                assert num_nextn_layers == 1, "Only 1 nextn layer is
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                 # compatible with old design
                 nextn_layer_id = (
                     0
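The renamed collectives operate on `tensor_split` views of a pre-allocated buffer, so gathered shards land in place. A rough sketch of that gather pattern in plain `torch.distributed` (the real `attn_tp_all_gather` lives in `sglang.srt.layers.dp_attention` and may differ in signature and grouping):

```python
import torch
import torch.distributed as dist

def all_gather_into_splits(full: torch.Tensor, local: torch.Tensor, group=None) -> torch.Tensor:
    # Assumes full.shape[0] is divisible by the group size, so the
    # tensor_split views are equally sized and contiguous.
    world_size = dist.get_world_size(group)
    shards = list(full.tensor_split(world_size))
    # all_gather writes each rank's `local` shard directly into the
    # corresponding view of `full`, so no copy is needed afterwards.
    dist.all_gather(shards, local.contiguous(), group=group)
    return full
```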
sglang/srt/models/llama.py
CHANGED

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
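The LM head's `use_attn_tp_group` is keyed off `enable_dp_lm_head` in `global_server_args_dict`; given that `server_args.py` also changes in this release, the flag is presumably exposed as a server argument. A hedged sketch follows; the `enable_dp_lm_head` kwarg name is an assumption mirroring the dict key:

```python
import sglang as sgl

# Assumed kwargs: enable_dp_lm_head mirrors the global_server_args_dict key
# used above; enable_dp_attention is the existing DP-attention flag it
# belongs with.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    tp_size=2,
    dp_size=2,
    enable_dp_attention=True,
    enable_dp_lm_head=True,
)
```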
sglang/srt/models/llama4.py
CHANGED

@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -198,7 +198,6 @@ class Llama4Attention(nn.Module):
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope

-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -342,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()

@@ -405,7 +404,7 @@ class Llama4DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (

@@ -428,9 +427,9 @@ class Llama4DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)

-        # TODO(ch-wan):
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
sglang/srt/models/llava.py
CHANGED

@@ -15,7 +15,8 @@

 import math
 import re
-from
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch

@@ -28,10 +29,18 @@ from transformers import (
     Qwen2Config,
     SiglipVisionModel,
 )
+from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

+# leave till last and symbol only in case circular import
+import sglang.srt.models as sgl_models
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import general_mm_embed_routine
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,

@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix, flatten_nested_list
+from sglang.srt.utils import add_prefix, flatten_nested_list, logger


 class LlavaBaseForCausalLM(nn.Module):

@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs.image_offsets = offset_list
         return input_ids

-    def encode_images(
+    def encode_images(
+        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        encode images by vision tower and multimodal projector
+        Args:
+            pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
+        Returns:
+            torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
+        """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         )


-EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
+    """
+    An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
+    It follows the structure of (vision_tower, multi_modal_projector, language_model)
+
+    Once a model config is loaded, text_config and vision_config will be extracted, and
+    LlavaForConditionalGeneration will load the language_model and vision_tower models
+    according to config.
+    """
+
+    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        if hasattr(self.vision_tower, "pad_input_ids"):
+            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
+        else:
+            return super().pad_input_ids(input_ids, image_inputs)
+
+    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
+        """
+        Get the SGLang model implementation class according to config.
+
+        Args:
+            config: The config object of the model.
+            auto_model_type: The type of the auto model.
+
+        Returns:
+            The SGLang model implementation class.
+        """
+        config_cls_name = config.__class__.__name__
+        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
+        if arch := arch_name_mapping.get(config_cls_name):
+            if isinstance(arch, tuple):
+                arch = arch[0]
+                logger.warning(
+                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
+                )
+            try:
+                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
+            except Exception as e:
+                raise ValueError(
+                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
+                )
+        else:
+            raise ValueError(
+                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
+            )
+
+    @lru_cache
+    def _config_cls_name_to_arch_name_mapping(
+        self, auto_model_type: Type[AutoModel]
+    ) -> Dict[str, str]:
+        mapping = {}
+        for config_cls, archs in auto_model_type._model_mapping.items():
+            if isinstance(archs, tuple):
+                mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
+            else:
+                mapping[config_cls.__name__] = archs.__name__
+        return mapping
+
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert hasattr(config, "text_config")
+        assert hasattr(config, "vision_config")
+        self.config = config
+        self.text_config = config.text_config
+        self.vision_config = config.vision_config
+
+        if not hasattr(self.config, "vocab_size"):
+            self.config.vocab_size = self.config.text_config.vocab_size
+        if not hasattr(self.config, "image_aspect_ratio"):
+            self.config.image_aspect_ratio = "anyres"
+        if not hasattr(self.config, "image_grid_pinpoints"):
+            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
+            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
+            self.config.image_grid_pinpoints = [
+                [96, 96],
+                [224, 224],
+                [384, 384],
+                [512, 512],
+                [768, 768],
+                [1024, 1024],
+            ]
+        if not hasattr(self.config, "mm_patch_merge_type"):
+            self.config.mm_patch_merge_type = "flat"
+        if not hasattr(self.config, "image_token_index"):
+            self.config.image_token_index = 10
+        if not hasattr(self.config, "projector_hidden_act"):
+            self.config.projector_hidden_act = "gelu"
+
+        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_select_strategy = getattr(
+            config, "vision_feature_select_strategy", "full"
+        )
+        self.image_size = self.config.vision_config.image_size
+        self.patch_size = self.config.vision_config.patch_size
+
+        self.mm_patch_merge_type = config.mm_patch_merge_type
+        self.image_aspect_ratio = config.image_aspect_ratio
+        self.image_grid_pinpoints = config.image_grid_pinpoints
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+
+        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
+
+        language_model_cls = self._get_sgl_model_cls(
+            config.text_config, AutoModelForCausalLM
+        )
+        vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
+        self.language_model = language_model_cls(
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.vision_tower = vision_model_cls(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(selected_image_feature.squeeze(0))
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            placeholder_tokens=None,  # using mm_item.pad_value
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for LlavaForConditionalGeneration.
+
+        Unlike the base class implementation, this one doesn't need to handle
+        weight name remapping as the weights are already properly structured with
+        'language_model' and 'vision_tower' prefixes in the safetensors files.
+        """
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(
+                f"Unexpected select feature: {self.vision_feature_select_strategy}"
+            )
+
+        # Create dictionaries for direct parameter loading
+        params_dict = dict(self.named_parameters())
+
+        # Load weights directly without remapping
+        for name, loaded_weight in weights:
+            for part in ("language_model", "vision_tower"):
+                if name.startswith(part):
+                    name = name[len(part + ".") :]
+                    getattr(self, part).load_weights([(name, loaded_weight)])
+                    break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaForConditionalGeneration,
+]
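Together with the new `sglang/srt/models/pixtral.py` and the pixtral multimodal processor added in this release, this adaptor is what lets checkpoints such as mistral-community/pixtral-12b (the model the class docstring itself names) load through the generic llava path. A hedged usage sketch; the exact prompt format and `image_data` payload shape may differ:

```python
import sglang as sgl

llm = sgl.Engine(model_path="mistral-community/pixtral-12b")
out = llm.generate(
    prompt="Describe this picture in one sentence.",
    image_data="https://example.com/cat.jpg",  # placeholder URL
    sampling_params={"max_new_tokens": 64},
)
print(out["text"])
```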