sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()

         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@
             reduce_results=False, # We need to do scatter before reduce
         )

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out

-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)

-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream


 class Llama4Attention(nn.Module):
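Note on the hunk above: for very small batches (fewer than 4 tokens), Llama4MoE now runs the shared expert on the current stream while the router and routed experts run on a second stream, synchronizing with wait_stream on both sides. The snippet below is a minimal, standalone sketch of that synchronization pattern using plain torch.cuda streams; it is illustrative only (two matmuls stand in for the shared and routed experts) and is not sglang code.

    import torch

    def overlapped(x: torch.Tensor, w_shared: torch.Tensor, w_routed: torch.Tensor) -> torch.Tensor:
        # The real code caches one side stream globally; created fresh here for brevity.
        alt_stream = torch.cuda.Stream()

        # Make the side stream see everything already queued on the current stream.
        alt_stream.wait_stream(torch.cuda.current_stream())

        shared_out = x @ w_shared  # runs on the default stream

        with torch.cuda.stream(alt_stream):
            routed_out = x @ w_routed  # runs concurrently on the side stream

        # Do not read routed_out until the side stream has produced it.
        torch.cuda.current_stream().wait_stream(alt_stream)
        return shared_out + routed_out

    if torch.cuda.is_available():
        x = torch.randn(2, 1024, device="cuda")
        w_a = torch.randn(1024, 1024, device="cuda")
        w_b = torch.randn(1024, 1024, device="cuda")
        print(overlapped(x, w_a, w_b).shape)  # torch.Size([2, 1024])
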
@@ -152,7 +198,6 @@ class Llama4Attention(nn.Module):
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope

-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -296,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()

@@ -359,7 +404,7 @@
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -380,11 +425,11 @@
             )

         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)

-        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
@@ -15,7 +15,8 @@

 import math
 import re
-from typing import Iterable, List, Optional, Tuple
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -28,10 +29,18 @@ from transformers import (
     Qwen2Config,
     SiglipVisionModel,
 )
+from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

+# leave till last and symbol only in case circular import
+import sglang.srt.models as sgl_models
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
+from sglang.srt.managers.mm_utils import general_mm_embed_routine
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix, flatten_nested_list
+from sglang.srt.utils import add_prefix, flatten_nested_list, logger


 class LlavaBaseForCausalLM(nn.Module):
@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs.image_offsets = offset_list
         return input_ids

-    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    def encode_images(
+        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        encode images by vision tower and multimodal projector
+        Args:
+            pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
+        Returns:
+            torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
+        """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         )


-EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
+    """
+    An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
+    It follows the structure of (vision_tower, multi_modal_projector, language_model)
+
+    Once a model config is loaded, text_config and vision_config will be extracted, and
+    LlavaForConditionalGeneration will load the language_model and vision_tower models
+    according to config.
+    """
+
+    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        if hasattr(self.vision_tower, "pad_input_ids"):
+            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
+        else:
+            return super().pad_input_ids(input_ids, image_inputs)
+
+    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
+        """
+        Get the SGLang model implementation class according to config.
+
+        Args:
+            config: The config object of the model.
+            auto_model_type: The type of the auto model.
+
+        Returns:
+            The SGLang model implementation class.
+        """
+        config_cls_name = config.__class__.__name__
+        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
+        if arch := arch_name_mapping.get(config_cls_name):
+            if isinstance(arch, tuple):
+                arch = arch[0]
+                logger.warning(
+                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
+                )
+            try:
+                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
+            except Exception as e:
+                raise ValueError(
+                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
+                )
+        else:
+            raise ValueError(
+                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
+            )
+
+    @lru_cache
+    def _config_cls_name_to_arch_name_mapping(
+        self, auto_model_type: Type[AutoModel]
+    ) -> Dict[str, str]:
+        mapping = {}
+        for config_cls, archs in auto_model_type._model_mapping.items():
+            if isinstance(archs, tuple):
+                mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
+            else:
+                mapping[config_cls.__name__] = archs.__name__
+        return mapping
+
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert hasattr(config, "text_config")
+        assert hasattr(config, "vision_config")
+        self.config = config
+        self.text_config = config.text_config
+        self.vision_config = config.vision_config
+
+        if not hasattr(self.config, "vocab_size"):
+            self.config.vocab_size = self.config.text_config.vocab_size
+        if not hasattr(self.config, "image_aspect_ratio"):
+            self.config.image_aspect_ratio = "anyres"
+        if not hasattr(self.config, "image_grid_pinpoints"):
+            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
+            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
+            self.config.image_grid_pinpoints = [
+                [96, 96],
+                [224, 224],
+                [384, 384],
+                [512, 512],
+                [768, 768],
+                [1024, 1024],
+            ]
+        if not hasattr(self.config, "mm_patch_merge_type"):
+            self.config.mm_patch_merge_type = "flat"
+        if not hasattr(self.config, "image_token_index"):
+            self.config.image_token_index = 10
+        if not hasattr(self.config, "projector_hidden_act"):
+            self.config.projector_hidden_act = "gelu"
+
+        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_select_strategy = getattr(
+            config, "vision_feature_select_strategy", "full"
+        )
+        self.image_size = self.config.vision_config.image_size
+        self.patch_size = self.config.vision_config.patch_size
+
+        self.mm_patch_merge_type = config.mm_patch_merge_type
+        self.image_aspect_ratio = config.image_aspect_ratio
+        self.image_grid_pinpoints = config.image_grid_pinpoints
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+
+        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
+
+        language_model_cls = self._get_sgl_model_cls(
+            config.text_config, AutoModelForCausalLM
+        )
+        vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
+        self.language_model = language_model_cls(
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.vision_tower = vision_model_cls(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(selected_image_feature.squeeze(0))
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            placeholder_tokens=None, # using mm_item.pad_value
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for LlavaForConditionalGeneration.
+
+        Unlike the base class implementation, this one doesn't need to handle
+        weight name remapping as the weights are already properly structured with
+        'language_model' and 'vision_tower' prefixes in the safetensors files.
+        """
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(
+                f"Unexpected select feature: {self.vision_feature_select_strategy}"
+            )
+
+        # Create dictionaries for direct parameter loading
+        params_dict = dict(self.named_parameters())
+
+        # Load weights directly without remapping
+        for name, loaded_weight in weights:
+            for part in ("language_model", "vision_tower"):
+                if name.startswith(part):
+                    name = name[len(part + ".") :]
+                    getattr(self, part).load_weights([(name, loaded_weight)])
+                    break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaForConditionalGeneration,
+]
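The load_weights method added above leans on Python's for/else clause: the else branch runs only when the inner loop finishes without hitting break, i.e. when a checkpoint tensor name starts with neither the "language_model" nor the "vision_tower" prefix and must be loaded directly. A minimal illustration of just that routing decision (not sglang code; the weight names are hypothetical):

    def route(name: str) -> str:
        """Return which loader the weight-routing loop above would pick for a tensor name."""
        for part in ("language_model", "vision_tower"):
            if name.startswith(part):
                target = f"delegate to self.{part}.load_weights (prefix stripped)"
                break
        else:
            # Reached only when the loop never hit `break`.
            target = "load directly via params_dict"
        return target

    print(route("vision_tower.transformer.layers.0.feed_forward.gate_proj.weight"))
    print(route("multi_modal_projector.linear_1.weight"))
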
@@ -197,7 +197,7 @@ class Idefics2EncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=config.attention_dropout,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=True,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""

-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)


 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()

-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,
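make_layers (imported from sglang.srt.utils; its body is not part of this diff) is what hands each pipeline rank its start_layer/end_layer slice and fills the remaining positions with PPMissingLayer placeholders. The sketch below shows the general idea under the assumption of an even split with the remainder on the last rank; the real helper's signature and partitioning policy may differ.

    from typing import Callable, Tuple

    import torch.nn as nn


    class MissingLayerStub(nn.Identity):
        """Illustrative stand-in for PPMissingLayer: a no-op for layers owned by other ranks."""


    def partition_layers(
        num_layers: int,
        layer_fn: Callable[[int], nn.Module],
        pp_rank: int,
        pp_size: int,
    ) -> Tuple[nn.ModuleList, int, int]:
        # Even split; the last rank absorbs any remainder.
        per_rank = num_layers // pp_size
        start = pp_rank * per_rank
        end = num_layers if pp_rank == pp_size - 1 else start + per_rank
        layers = nn.ModuleList(
            layer_fn(i) if start <= i < end else MissingLayerStub()
            for i in range(num_layers)
        )
        return layers, start, end


    # Example: 32 decoder layers over 4 pipeline ranks; rank 1 owns layers 8..15.
    layers, start, end = partition_layers(32, lambda i: nn.Linear(8, 8), pp_rank=1, pp_size=4)
    print(start, end)  # 8 16
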
@@ -282,18 +295,35 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states

@@ -306,6 +336,7 @@
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -398,11 +459,14 @@
                     if name is None:
                         continue

-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")


 EntryClass = MixtralForCausalLM
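get_layer_id, used in the load_weights change above to skip tensors that belong to layers outside this rank's [start_layer, end_layer) slice, is imported from sglang.srt.layers.utils and its implementation is not shown in this diff. The regex-based helper below is only a sketch of what such a function can look like, not the actual one:

    import re
    from typing import Optional

    _LAYER_RE = re.compile(r"\blayers\.(\d+)\.")


    def layer_id_of(weight_name: str) -> Optional[int]:
        """Return the decoder-layer index embedded in a checkpoint tensor name, if any."""
        match = _LAYER_RE.search(weight_name)
        return int(match.group(1)) if match else None


    print(layer_id_of("model.layers.17.self_attn.q_proj.weight"))  # 17
    print(layer_id_of("model.embed_tokens.weight"))                # None
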
@@ -203,7 +203,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=0.0,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),