sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py
@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
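
Note: the new `tie_word_embeddings` branch reuses the input embedding matrix (`transformer.wte`) as the output projection instead of allocating a separate `ParallelLMHead`. A minimal plain-PyTorch sketch of that weight-tying pattern (illustrative only; SGLang's `ParallelLMHead` additionally shards the matrix across tensor-parallel ranks):

import torch
from torch import nn

# Weight tying: the LM head and the input embedding share one parameter,
# so logits are computed against the same matrix used to embed tokens.
vocab_size, hidden_size = 1000, 64
wte = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = wte.weight  # one tensor, two roles

hidden = torch.randn(2, 8, hidden_size)
logits = lm_head(hidden)  # shape (2, 8, vocab_size)
assert lm_head.weight.data_ptr() == wte.weight.data_ptr()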
sglang/srt/models/gemma3_mm.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
 import torch
 from torch import nn
-from transformers import AutoModel, Gemma3Config, PreTrainedModel
+from transformers import Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".out_proj.",
     ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         "v_proj": ("qkv_proj", 2),
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
     }
 
     packed_modules_mapping = {
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         super().__init__(config=config)
         self.config = config
         self.quant_config = quant_config
-        # Vision components
-        # TODO: replace with vision attention
-        # self.vision_tower = SiglipVisionModel(
-        #     config.vision_config,
-        #     quant_config,
-        #     prefix=add_prefix("vision_tower", prefix),
-        # )
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+
+        self.vision_tower = SiglipVisionModel(
+            config=config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
 
         # Text model
         self.language_model = Gemma3ForCausalLM(
-            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         if self.language_model.logits_processor.logit_scale:
             logit_scale = getattr(config, "logit_scale", 1.0)
@@ -278,13 +282,28 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = torch.stack(
-            flatten_nested_list([item.pixel_values for item in items]), dim=0
-        )
-        pixel_values = pixel_values.to(device=self.vision_tower.device)
-        pixel_values = pixel_values.to(dtype=self.language_model.dtype())
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
 
-        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_value in all_pixel_values:
+            # Add batch dimension for single image processing
+            pixel_value_batch = pixel_value.unsqueeze(0)
+            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
+            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+
+            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
+            vision_outputs_list.append(vision_output)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
 
@@ -360,6 +379,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return self.language_model.tie_weights()
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
         """Load weights for the model."""
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -373,21 +400,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 loaded_params.update(causal_loaded_params)
                 continue
             else:
-                # Skip lm_head.weight as it's tied with embed_tokens
-                if "lm_head.weight" in name:
-                    continue
-
-                # Skip loading extra bias for GPTQ models
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                # Remapping the name of FP8 kv-scale
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                    # Skip loading extra bias for GPTQ models
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
@@ -398,5 +437,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 
 
 EntryClass = Gemma3ForConditionalGeneration
-
-AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
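
Note: the rewritten `get_image_feature` above takes an all-or-nothing stance on precomputed features and otherwise encodes images one at a time (batch dimension of 1) before re-batching with `torch.cat`. A minimal standalone sketch of that control flow; the `Item` dataclass and `encode` callable are hypothetical stand-ins for `MultimodalDataItem` and the vision tower:

from dataclasses import dataclass
from typing import Callable, List, Optional

import torch


@dataclass
class Item:  # hypothetical stand-in for MultimodalDataItem
    pixel_values: Optional[torch.Tensor] = None  # (C, H, W)
    precomputed_features: Optional[torch.Tensor] = None


def get_features(items: List[Item], encode: Callable[[torch.Tensor], torch.Tensor]):
    # All-or-nothing: either every item already carries features, or none does.
    if any(i.precomputed_features is not None for i in items):
        if not all(i.precomputed_features is not None for i in items):
            raise NotImplementedError("mixed precomputed/raw items in one batch")
        return torch.cat([i.precomputed_features for i in items])
    # Encode raw pixels one image at a time (batch dim of 1), then re-batch.
    return torch.cat([encode(i.pixel_values.unsqueeze(0)) for i in items], dim=0)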
sglang/srt/models/llama.py
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
                 config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

sglang/srt/models/llama4.py
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
+from sglang.srt.utils import (
+    add_prefix,
+    fast_topk,
+    get_compiler_backend,
+    is_cuda,
+    make_layers,
+)
+
+_is_cuda = is_cuda()
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +139,7 @@ class Llama4MoE(nn.Module):
         return out_aD
 
     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4:
+        if hidden_states.shape[0] < 4 and _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
@@ -198,7 +206,6 @@ class Llama4Attention(nn.Module):
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
 
@@ -342,7 +349,7 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
@@ -405,7 +412,7 @@ class Llama4DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -428,9 +435,9 @@ class Llama4DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
 
-        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (

sglang/srt/models/llava.py
@@ -15,7 +15,8 @@
 
 import math
 import re
-from typing import Iterable, List, Optional, Tuple
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -28,10 +29,18 @@ from transformers import (
     Qwen2Config,
     SiglipVisionModel,
 )
+from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
+# leave till last and symbol only in case circular import
+import sglang.srt.models as sgl_models
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
+from sglang.srt.managers.mm_utils import general_mm_embed_routine
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix, flatten_nested_list
+from sglang.srt.utils import add_prefix, flatten_nested_list, logger
 
 
 class LlavaBaseForCausalLM(nn.Module):
@@ -114,10 +123,18 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs.image_offsets = offset_list
         return input_ids
 
-    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    def encode_images(
+        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        encode images by vision tower and multimodal projector
+        Args:
+            pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
+        Returns:
+            torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
+        """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-
         selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
         if self.vision_feature_select_strategy in ["default", "patch"]:
             selected_image_feature = selected_image_feature[:, 1:]
@@ -128,7 +145,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )
         image_features = self.multi_modal_projector(selected_image_feature)
-
         return image_features
 
     @torch.no_grad()
@@ -583,4 +599,239 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         )
 
 
-EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
+    """
+    An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
+    It follows the structure of (vision_tower, multi_modal_projector, language_model)
+
+    Once a model config is loaded, text_config and vision_config will be extracted, and
+    LlavaForConditionalGeneration will load the language_model and vision_tower models
+    according to config.
+    """
+
+    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
+
+    @property
+    def dtype(self):
+        return self.torch_dtype
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        if hasattr(self.vision_tower, "pad_input_ids"):
+            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
+        else:
+            return super().pad_input_ids(input_ids, image_inputs)
+
+    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
+        """
+        Get the SGLang model implementation class according to config.
+
+        Args:
+            config: The config object of the model.
+            auto_model_type: The type of the auto model.
+
+        Returns:
+            The SGLang model implementation class.
+        """
+        config_cls_name = config.__class__.__name__
+        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
+        if arch := arch_name_mapping.get(config_cls_name):
+            if isinstance(arch, tuple):
+                arch = arch[0]
+                logger.warning(
+                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
+                )
+            try:
+                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
+            except Exception as e:
+                raise ValueError(
+                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
+                )
+        else:
+            raise ValueError(
+                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
+            )
+
+    @lru_cache
+    def _config_cls_name_to_arch_name_mapping(
+        self, auto_model_type: Type[AutoModel]
+    ) -> Dict[str, str]:
+        mapping = {}
+        for config_cls, archs in auto_model_type._model_mapping.items():
+            if isinstance(archs, tuple):
+                mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
+            else:
+                mapping[config_cls.__name__] = archs.__name__
+        return mapping
+
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert hasattr(config, "text_config")
+        assert hasattr(config, "vision_config")
+        self.config = config
+        self.text_config = self.config.text_config
+        self.vision_config = self.config.vision_config
+        self.torch_dtype = getattr(self.config, "torch_dtype")
+
+        if not getattr(self.text_config, "torch_dtype"):
+            self.text_config.torch_dtype = self.torch_dtype
+        if not getattr(self.vision_config, "torch_dtype"):
+            self.vision_config.torch_dtype = self.torch_dtype
+
+        if not hasattr(self.config, "vocab_size"):
+            self.config.vocab_size = self.text_config.vocab_size
+        if not hasattr(self.config, "image_aspect_ratio"):
+            self.config.image_aspect_ratio = "anyres"
+        if not hasattr(self.config, "image_grid_pinpoints"):
+            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
+            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
+            self.config.image_grid_pinpoints = [
+                [96, 96],
+                [224, 224],
+                [384, 384],
+                [512, 512],
+                [768, 768],
+                [1024, 1024],
+            ]
+        if not hasattr(self.config, "mm_patch_merge_type"):
+            self.config.mm_patch_merge_type = "flat"
+        if not hasattr(self.config, "image_token_index"):
+            self.config.image_token_index = 10
+        if not hasattr(self.config, "projector_hidden_act"):
+            self.config.projector_hidden_act = "gelu"
+
+        self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
+        self.vision_feature_select_strategy = getattr(
+            self.config, "vision_feature_select_strategy", "full"
+        )
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size
+
+        self.mm_patch_merge_type = self.config.mm_patch_merge_type
+        self.image_aspect_ratio = self.config.image_aspect_ratio
+        self.image_grid_pinpoints = self.config.image_grid_pinpoints
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+
+        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
+
+        language_model_cls = self._get_sgl_model_cls(
+            self.text_config, AutoModelForCausalLM
+        )
+        vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
+        self.language_model = language_model_cls(
+            self.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.vision_tower = vision_model_cls(
+            self.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
+        if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
+            )
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(selected_image_feature.squeeze(0))
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            placeholder_tokens=None,  # using mm_item.pad_value
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for LlavaForConditionalGeneration.
+
+        Unlike the base class implementation, this one doesn't need to handle
+        weight name remapping as the weights are already properly structured with
+        'language_model' and 'vision_tower' prefixes in the safetensors files.
+        """
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(
+                f"Unexpected select feature: {self.vision_feature_select_strategy}"
+            )
+
+        # Create dictionaries for direct parameter loading
+        params_dict = dict(self.named_parameters())
+
+        # Load weights directly without remapping
+        for name, loaded_weight in weights:
+            for part in ("language_model", "vision_tower"):
+                if name.startswith(part):
+                    name = name[len(part + ".") :]
+                    getattr(self, part).load_weights([(name, loaded_weight)])
+                    break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaForConditionalGeneration,
+]
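
Note: the `_get_sgl_model_cls` helper in the new `LlavaForConditionalGeneration` resolves a submodule config in two steps: map the config class name to an architecture name through a transformers auto-model registry, then look that name up in SGLang's `ModelRegistry`. A minimal sketch of just the first step, mirroring the private `_model_mapping` iteration the new code itself relies on (the exact mapping can vary across transformers versions; the SGLang registry lookup is elided):

from transformers.models.auto.modeling_auto import AutoModelForCausalLM

# Build {config class name -> architecture name(s)} from the auto-model
# registry, as _config_cls_name_to_arch_name_mapping does above.
mapping = {
    config_cls.__name__: (
        tuple(a.__name__ for a in archs) if isinstance(archs, tuple) else archs.__name__
    )
    for config_cls, archs in AutoModelForCausalLM._model_mapping.items()
}
print(mapping.get("MistralConfig"))  # e.g. "MistralForCausalLM"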