sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -400,11 +401,15 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )

         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
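
Both get_available_gpu_memory call sites in this file now reduce over the whole world group, using a CPU process group for the collective, instead of only the tensor-parallel group. A minimal standalone sketch of that pattern, assuming torch.distributed is initialized with a gloo-backed CPU group; the helper name below is illustrative, not sglang's get_available_gpu_memory:

    import torch
    import torch.distributed as dist

    def min_free_gpu_memory_gb(gpu_id: int, cpu_group=None) -> float:
        # Free memory on this rank's GPU, in GiB.
        free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_id)
        free_gb = free_bytes / (1 << 30)
        if dist.is_initialized() and dist.get_world_size(cpu_group) > 1:
            # MIN-reduce on a CPU (gloo) group so the sync avoids extra CUDA work.
            t = torch.tensor([free_gb], dtype=torch.float64)
            dist.all_reduce(t, op=dist.ReduceOp.MIN, group=cpu_group)
            free_gb = t.item()
        return free_gb
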
@@ -716,7 +721,10 @@ class ModelRunner:

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
@@ -1085,32 +1093,33 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            return self.cuda_graph_runner.replay(
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-        if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
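
With this change, ModelRunner.forward returns a (output, can_run_cuda_graph) pair instead of the output alone, so callers can tell whether the batch was served by a CUDA-graph replay. A hedged caller-side sketch; the function name and stats dict are illustrative, not the scheduler's actual code:

    def run_forward_step(model_runner, forward_batch, stats: dict):
        # forward() now yields (LogitsProcessorOutput | PPProxyTensors, bool).
        output, can_run_cuda_graph = model_runner.forward(forward_batch)
        if can_run_cuda_graph:
            # e.g. count how often CUDA-graph replay was hit (illustrative metric).
            stats["cuda_graph_replays"] = stats.get("cuda_graph_replays", 0) + 1
        return output
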
@@ -1145,9 +1154,7 @@ class ModelRunner:
             [self.sample(values, forward_batch) for values in logits_output],
             axis=-1,
         )
-        sampling_info = forward_batch.sampling_info
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
+
         self._preprocess_logits(logits_output, forward_batch.sampling_info)

         # Sample the next tokens
@@ -1158,8 +1165,6 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids

     @property
@@ -188,7 +188,7 @@ def trunc_normal_tf_(
     best when :math:`a \\leq \text{mean} \\leq b`.
     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
-    and the result is subsquently scaled and shifted by the mean and std args.
+    and the result is subsequently scaled and shifted by the mean and std args.
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
@@ -735,7 +735,7 @@ class VisionTransformer(nn.Module):
         img_size: Input image size.
         patch_size: Patch size.
         in_chans: Number of image input channels.
-        num_classes: Mumber of classes for classification head.
+        num_classes: Number of classes for classification head.
         global_pool: Type of global pooling for final sequence (default: 'token').
         embed_dim: Transformer embedding dimension.
         depth: Depth of transformer.
@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -1287,9 +1287,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.mlp(hidden_states)

-        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             if hidden_states.shape[0] != 0:
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -1475,9 +1475,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
@@ -1486,22 +1487,24 @@ class DeepseekV2ForCausalLM(nn.Module):
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != architecture
+                not _is_cuda
+                or self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
                 global_server_args_dict["n_share_experts_fusion"] = 0
                 log_info_on_rank0(
                     logger,
-                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
+                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
                     self.n_share_experts_fusion == self.tp_size
-                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
         elif self.n_share_experts_fusion == 0:
             if (
-                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                _is_cuda
+                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                 and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
@@ -1663,7 +1666,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
-                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                 # compatible with old design
                 nextn_layer_id = (
                     0
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -198,7 +198,6 @@ class Llama4Attention(nn.Module):
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope

-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -342,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()

@@ -405,7 +404,7 @@ class Llama4DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -428,9 +427,9 @@ class Llama4DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)

-        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
@@ -15,7 +15,8 @@

 import math
 import re
-from typing import Iterable, List, Optional, Tuple
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -28,10 +29,18 @@ from transformers import (
     Qwen2Config,
     SiglipVisionModel,
 )
+from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

+# leave till last and symbol only in case circular import
+import sglang.srt.models as sgl_models
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
+from sglang.srt.managers.mm_utils import general_mm_embed_routine
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix, flatten_nested_list
+from sglang.srt.utils import add_prefix, flatten_nested_list, logger


 class LlavaBaseForCausalLM(nn.Module):
@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs.image_offsets = offset_list
         return input_ids

-    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    def encode_images(
+        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        encode images by vision tower and multimodal projector
+        Args:
+            pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
+        Returns:
+            torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
+        """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
     )


-EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
+    """
+    An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
+    It follows the structure of (vision_tower, multi_modal_projector, language_model)
+
+    Once a model config is loaded, text_config and vision_config will be extracted, and
+    LlavaForConditionalGeneration will load the language_model and vision_tower models
+    according to config.
+    """
+
+    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        if hasattr(self.vision_tower, "pad_input_ids"):
+            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
+        else:
+            return super().pad_input_ids(input_ids, image_inputs)
+
+    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
+        """
+        Get the SGLang model implementation class according to config.
+
+        Args:
+            config: The config object of the model.
+            auto_model_type: The type of the auto model.
+
+        Returns:
+            The SGLang model implementation class.
+        """
+        config_cls_name = config.__class__.__name__
+        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
+        if arch := arch_name_mapping.get(config_cls_name):
+            if isinstance(arch, tuple):
+                arch = arch[0]
+                logger.warning(
+                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
+                )
+            try:
+                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
+            except Exception as e:
+                raise ValueError(
+                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
+                )
+        else:
+            raise ValueError(
+                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
+            )
+
+    @lru_cache
+    def _config_cls_name_to_arch_name_mapping(
+        self, auto_model_type: Type[AutoModel]
+    ) -> Dict[str, str]:
+        mapping = {}
+        for config_cls, archs in auto_model_type._model_mapping.items():
+            if isinstance(archs, tuple):
+                mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
+            else:
+                mapping[config_cls.__name__] = archs.__name__
+        return mapping
+
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert hasattr(config, "text_config")
+        assert hasattr(config, "vision_config")
+        self.config = config
+        self.text_config = config.text_config
+        self.vision_config = config.vision_config
+
+        if not hasattr(self.config, "vocab_size"):
+            self.config.vocab_size = self.config.text_config.vocab_size
+        if not hasattr(self.config, "image_aspect_ratio"):
+            self.config.image_aspect_ratio = "anyres"
+        if not hasattr(self.config, "image_grid_pinpoints"):
+            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
+            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
+            self.config.image_grid_pinpoints = [
+                [96, 96],
+                [224, 224],
+                [384, 384],
+                [512, 512],
+                [768, 768],
+                [1024, 1024],
+            ]
+        if not hasattr(self.config, "mm_patch_merge_type"):
+            self.config.mm_patch_merge_type = "flat"
+        if not hasattr(self.config, "image_token_index"):
+            self.config.image_token_index = 10
+        if not hasattr(self.config, "projector_hidden_act"):
+            self.config.projector_hidden_act = "gelu"
+
+        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_select_strategy = getattr(
+            config, "vision_feature_select_strategy", "full"
+        )
+        self.image_size = self.config.vision_config.image_size
+        self.patch_size = self.config.vision_config.patch_size
+
+        self.mm_patch_merge_type = config.mm_patch_merge_type
+        self.image_aspect_ratio = config.image_aspect_ratio
+        self.image_grid_pinpoints = config.image_grid_pinpoints
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+
+        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
+
+        language_model_cls = self._get_sgl_model_cls(
+            config.text_config, AutoModelForCausalLM
+        )
+        vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
+        self.language_model = language_model_cls(
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.vision_tower = vision_model_cls(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(selected_image_feature.squeeze(0))
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            placeholder_tokens=None,  # using mm_item.pad_value
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for LlavaForConditionalGeneration.
+
+        Unlike the base class implementation, this one doesn't need to handle
+        weight name remapping as the weights are already properly structured with
+        'language_model' and 'vision_tower' prefixes in the safetensors files.
+        """
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(
+                f"Unexpected select feature: {self.vision_feature_select_strategy}"
+            )
+
+        # Create dictionaries for direct parameter loading
+        params_dict = dict(self.named_parameters())
+
+        # Load weights directly without remapping
+        for name, loaded_weight in weights:
+            for part in ("language_model", "vision_tower"):
+                if name.startswith(part):
+                    name = name[len(part + ".") :]
+                    getattr(self, part).load_weights([(name, loaded_weight)])
+                    break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaForConditionalGeneration,
+]
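
The new _get_sgl_model_cls/_config_cls_name_to_arch_name_mapping pair resolves a submodule config (text_config or vision_config) to an SGLang implementation by going through the transformers auto-model registry and then sglang's ModelRegistry. A small standalone sketch of the first half of that lookup, assuming transformers is installed; _model_mapping is a private transformers attribute, mirrored here only to illustrate the approach the diff takes:

    from transformers.models.auto.modeling_auto import AutoModelForCausalLM

    def config_to_arch_names(auto_model_type=AutoModelForCausalLM) -> dict:
        # Map a config class name (e.g. "MistralConfig") to its architecture
        # class name(s), the same mapping LlavaForConditionalGeneration builds
        # before asking sglang's ModelRegistry for the matching implementation.
        mapping = {}
        for config_cls, archs in auto_model_type._model_mapping.items():
            if isinstance(archs, tuple):
                mapping[config_cls.__name__] = tuple(a.__name__ for a in archs)
            else:
                mapping[config_cls.__name__] = archs.__name__
        return mapping

    # Example: "MistralConfig" is expected to resolve to "MistralForCausalLM".
    print(config_to_arch_names().get("MistralConfig"))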