sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0

sglang/srt/models/mimo_mtp.py
@@ -0,0 +1,220 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.mimo import MiMoForCausalLM
+from sglang.srt.models.qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2MLP,
+    Qwen2Model,
+)
+from sglang.srt.utils import add_prefix
+
+
+class MiMoMultiTokenPredictorLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        prefix: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_proj = nn.Linear(
+            config.hidden_size * 2, config.hidden_size, bias=False
+        )
+        self.mtp_block = Qwen2DecoderLayer(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        # masking inputs at position 0, as not needed by MTP
+        hidden_states[positions == 0] = 0
+
+        hidden_states = self.input_proj(
+            torch.cat(
+                (
+                    self.hidden_layernorm(forward_batch.spec_info.hidden_states),
+                    self.token_layernorm(hidden_states),
+                ),
+                dim=-1,
+            )
+        )
+
+        hidden_states, residual = self.mtp_block(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            residual=None,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class MiMoMTP(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+
+        self.model = MiMoMultiTokenPredictorLayer(
+            config,
+            prefix,
+            quant_config,
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+            name = self.map_model_name_to_mtp_param_name(name)
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mtp_block" not in name:
+                    break
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if "mtp_block" not in name and (
+                    "embed_tokens" not in name
+                    and "lm_head" not in name
+                    and "token_layernorm" not in name
+                    and "hidden_layernorm" not in name
+                    and "input_proj" not in name
+                    and "final_layernorm" not in name
+                ):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def map_model_name_to_mtp_param_name(self, name: str) -> str:
+        import re
+
+        name_without_prefix = [
+            "token_layernorm",
+            "hidden_layernorm",
+            "input_proj",
+            "final_layernorm",
+        ]
+        pattern = r"model.mtp_layers.(\d+)."
+        group = re.match(pattern, name)
+        if group is not None:
+            for sub_name in name_without_prefix:
+                if sub_name in name:
+                    name = name.replace(group.group(), "model.")
+                    return name
+            name = name.replace(group.group(), "model.mtp_block.")
+        return name
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+
+EntryClass = MiMoMTP

sglang/srt/models/minicpmo.py
@@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel):
         slice_start_id: int = mm_input.slice_start_id
         slice_end_id: int = mm_input.slice_end_id

-        media_token_pairs = [
+        data_token_pairs = [
             (im_start_id, im_end_id),
             (slice_start_id, slice_end_id),
             (mm_input.audio_start_id, mm_input.audio_end_id),
         ]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        data_start_token_ids = [im_start_id, mm_input.audio_start_id]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(
+            data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids
+        )

         return pattern.pad_input_tokens(input_ids, mm_input)

@@ -1823,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel):
         **kwargs: Any,
     ) -> torch.Tensor:

-        mm_input = forward_batch.merge_mm_inputs()
-        placeholder_token_ids = (
-            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
-            if forward_batch.contains_mm_inputs()
-            else []
-        )
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_tokens={
-                Modality.IMAGE: placeholder_token_ids,
-                Modality.AUDIO: placeholder_token_ids,
-            },
             positions=positions,
         )
         return hidden_states

sglang/srt/models/mistral.py
@@ -13,6 +13,12 @@
 # ==============================================================================
 """Inference-only Mistral model."""

+from typing import List, Union
+
+import torch
+from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
+
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -20,4 +26,68 @@ class MistralForCausalLM(LlamaForCausalLM):
     pass


-EntryClass = MistralForCausalLM
+class Mistral3ForConditionalGeneration:
+    MULTIMODAL_PROJECTOR_TYPE = Mistral3MultiModalProjector
+
+    def __init__(self, **kwargs):
+        # lazy load inner class
+        # to bypass circular import
+        from sglang.srt.models.llava import LlavaForConditionalGeneration
+
+        # override config: mistral's projector adds patchmerger that doesn't require padding
+        kwargs["config"].vision_config.pad_image_border = False
+
+        self.inner = LlavaForConditionalGeneration(**kwargs)
+        self.inner.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(
+            kwargs["config"]
+        )
+        self.inner.get_image_feature = self.get_image_feature
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(
+                    selected_image_feature.squeeze(0), image_sizes
+                )
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def __getattr__(self, name):
+        return getattr(self.inner, name)
+
+    def __hasattr__(self, name):
+        return hasattr(self.inner, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.inner(*args, **kwargs)
+
+
+EntryClass = [MistralForCausalLM, Mistral3ForConditionalGeneration]

sglang/srt/models/mixtral.py
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""

-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)


 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()

-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states


@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                     if name is None:
                         continue

-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")


 EntryClass = MixtralForCausalLM

sglang/srt/models/mllama.py
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
             prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
-        self.capture_mode = False

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pixel_values = torch.cat(
@@ -865,7 +864,6 @@ class MllamaForConditionalGeneration(nn.Module):
                 pixel_values = torch.cat(
                     [item.pixel_values for item in mm_input.mm_items], dim=0
                 )
-                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
                 max_num_images = max(max_num_images, pixel_values.shape[1])

                 max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
@@ -970,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
         )
@@ -978,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_mask = None
         cross_attention_states = None

-        if self.capture_mode:
+        if get_is_capture_mode():
             # NOTE: when doing cuda graph capture, we do not want to skip cross attention
             # Make is a constant value to avoid cuda graph capture issue
             skip_cross_attention = False