sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py
@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
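
The Exaone change above reuses the input embedding as the LM head when `tie_word_embeddings` is set, instead of always allocating a separate `ParallelLMHead`. Below is a minimal sketch of that tying pattern with plain `torch.nn` modules; `TinyLM` and `TinyConfig` are toy stand-ins, not sglang classes.

```python
from torch import nn

class TinyConfig:
    vocab_size = 8
    hidden_size = 4
    tie_word_embeddings = True

class TinyLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            # Reuse the embedding as the output head: one weight matrix, not two.
            self.lm_head = self.wte
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

model = TinyLM(TinyConfig())
# Both attributes point at the same storage, so no extra vocab-sized matrix is allocated.
assert model.lm_head.weight.data_ptr() == model.wte.weight.data_ptr()
```
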
sglang/srt/models/gemma3_mm.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

 import torch
 from torch import nn
-from transformers import AutoModel, Gemma3Config, PreTrainedModel
+from transformers import Gemma3Config, PreTrainedModel

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".out_proj.",
     ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         "v_proj": ("qkv_proj", 2),
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
     }

     packed_modules_mapping = {
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         super().__init__(config=config)
         self.config = config
         self.quant_config = quant_config
-        # Vision components
-        # TODO: replace with vision attention
-        # self.vision_tower = SiglipVisionModel(
-        #     config.vision_config,
-        #     quant_config,
-        #     prefix=add_prefix("vision_tower", prefix),
-        # )
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+
+        self.vision_tower = SiglipVisionModel(
+            config=config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size

         # Text model
         self.language_model = Gemma3ForCausalLM(
-            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         if self.language_model.logits_processor.logit_scale:
             logit_scale = getattr(config, "logit_scale", 1.0)
@@ -278,13 +282,28 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = torch.stack(
-            flatten_nested_list([item.pixel_values for item in items]), dim=0
-        )
-        pixel_values = pixel_values.to(device=self.vision_tower.device)
-        pixel_values = pixel_values.to(dtype=self.language_model.dtype())
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])

-        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_value in all_pixel_values:
+            # Add batch dimension for single image processing
+            pixel_value_batch = pixel_value.unsqueeze(0)
+            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
+            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+
+            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
+            vision_outputs_list.append(vision_output)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features

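
The rewritten `get_image_feature` above first checks for precomputed features and only falls back to running the vision tower image by image. Here is a hedged sketch of that fast-path check; `Item` is a toy stand-in for `MultimodalDataItem`, not sglang's class.

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class Item:
    # Toy stand-in: either features were precomputed upstream,
    # or pixel_values would be run through the vision tower.
    precomputed_features: Optional[torch.Tensor] = None
    pixel_values: Optional[torch.Tensor] = None

def gather_precomputed(items):
    if any(i.precomputed_features is not None for i in items):
        if not all(i.precomputed_features is not None for i in items):
            raise NotImplementedError("MM inputs where only some items are precomputed.")
        return torch.concat([i.precomputed_features for i in items])
    return None  # caller falls back to the per-image vision tower path

feats = gather_precomputed([Item(torch.ones(1, 4)), Item(torch.ones(2, 4))])
print(feats.shape)  # torch.Size([3, 4])
```
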
@@ -360,6 +379,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return self.language_model.tie_weights()

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
         """Load weights for the model."""
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -373,21 +400,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 loaded_params.update(causal_loaded_params)
                 continue
             else:
-                # Skip lm_head.weight as it's tied with embed_tokens
-                if "lm_head.weight" in name:
-                    continue
-
-                # Skip loading extra bias for GPTQ models
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                # Remapping the name of FP8 kv-scale
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                    # Skip loading extra bias for GPTQ models
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
@@ -398,5 +437,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):


 EntryClass = Gemma3ForConditionalGeneration
-
-AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
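
The new `load_weights` path above folds per-projection checkpoint weights (`q_proj`/`k_proj`/`v_proj`, `gate_proj`/`up_proj`) into the fused `qkv_proj`/`gate_up_proj` parameters and remaps the vision tower's `out_proj` onto `VisionAttention`'s `proj`. The sketch below isolates just that name resolution; the `resolve` helper is illustrative, not part of sglang, and the real loader additionally consults `params_dict` and per-parameter `weight_loader`s.

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id) -- same table as in the diff above
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    ("gate_up_proj", "up_proj", 1),
    ("gate_up_proj", "gate_proj", 0),
]

def resolve(name: str):
    """Return (target_param_name, shard_id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            # Fused parameter: rename and remember which shard this weight fills.
            return name.replace(weight_name, param_name), shard_id
    if "vision_model" in name:
        # Unfused vision weight: adapt "out_proj" to VisionAttention's "proj".
        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
    return name, None

print(resolve("vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight"))
# ('vision_tower.vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'q')
print(resolve("vision_tower.vision_model.encoder.layers.0.self_attn.out_proj.weight"))
# ('vision_tower.vision_model.encoder.layers.0.self_attn.proj.weight', None)
```
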
sglang/srt/models/llama4.py
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
+from sglang.srt.utils import (
+    add_prefix,
+    fast_topk,
+    get_compiler_backend,
+    is_cuda,
+    make_layers,
+)
+
+_is_cuda = is_cuda()

 logger = logging.getLogger(__name__)

@@ -131,7 +139,7 @@ class Llama4MoE(nn.Module):
         return out_aD

     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4:
+        if hidden_states.shape[0] < 4 and _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
sglang/srt/models/llava.py
@@ -135,7 +135,6 @@ class LlavaBaseForCausalLM(nn.Module):
         """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-
         selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
         if self.vision_feature_select_strategy in ["default", "patch"]:
             selected_image_feature = selected_image_feature[:, 1:]
@@ -146,7 +145,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
             )
         image_features = self.multi_modal_projector(selected_image_feature)
-
         return image_features

     @torch.no_grad()
@@ -613,6 +611,10 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):

     MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector

+    @property
+    def dtype(self):
+        return self.torch_dtype
+
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         if hasattr(self.vision_tower, "pad_input_ids"):
             return self.vision_tower.pad_input_ids(input_ids, image_inputs)
@@ -672,11 +674,17 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         assert hasattr(config, "text_config")
         assert hasattr(config, "vision_config")
         self.config = config
-        self.text_config = config.text_config
-        self.vision_config = config.vision_config
+        self.text_config = self.config.text_config
+        self.vision_config = self.config.vision_config
+        self.torch_dtype = getattr(self.config, "torch_dtype")
+
+        if not getattr(self.text_config, "torch_dtype"):
+            self.text_config.torch_dtype = self.torch_dtype
+        if not getattr(self.vision_config, "torch_dtype"):
+            self.vision_config.torch_dtype = self.torch_dtype

         if not hasattr(self.config, "vocab_size"):
-            self.config.vocab_size = self.config.text_config.vocab_size
+            self.config.vocab_size = self.text_config.vocab_size
         if not hasattr(self.config, "image_aspect_ratio"):
             self.config.image_aspect_ratio = "anyres"
         if not hasattr(self.config, "image_grid_pinpoints"):
@@ -697,39 +705,39 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         if not hasattr(self.config, "projector_hidden_act"):
             self.config.projector_hidden_act = "gelu"

-        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
         self.vision_feature_select_strategy = getattr(
-            config, "vision_feature_select_strategy", "full"
+            self.config, "vision_feature_select_strategy", "full"
         )
-        self.image_size = self.config.vision_config.image_size
-        self.patch_size = self.config.vision_config.patch_size
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size

-        self.mm_patch_merge_type = config.mm_patch_merge_type
-        self.image_aspect_ratio = config.image_aspect_ratio
-        self.image_grid_pinpoints = config.image_grid_pinpoints
+        self.mm_patch_merge_type = self.config.mm_patch_merge_type
+        self.image_aspect_ratio = self.config.image_aspect_ratio
+        self.image_grid_pinpoints = self.config.image_grid_pinpoints

         self.image_feature_len = int((self.image_size // self.patch_size) ** 2)

         self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)

         language_model_cls = self._get_sgl_model_cls(
-            config.text_config, AutoModelForCausalLM
+            self.text_config, AutoModelForCausalLM
         )
-        vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
+        vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
         self.language_model = language_model_cls(
-            config.text_config,
+            self.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
         self.vision_tower = vision_model_cls(
-            config.vision_config,
+            self.vision_config,
             quant_config=quant_config,
             prefix=add_prefix("vision_tower", prefix),
         )

-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+        if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+                torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
             )

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
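
The Llava `__init__` changes above read the checkpoint's `torch_dtype` once and push it into the text and vision sub-configs (and into `image_newline`) instead of hard-coding `torch.float16`. A small sketch of that propagation, using `SimpleNamespace` stand-ins for the HF config objects:

```python
from types import SimpleNamespace

import torch

# Stand-ins for the HF config objects; only torch_dtype matters here.
config = SimpleNamespace(
    torch_dtype=torch.bfloat16,
    text_config=SimpleNamespace(torch_dtype=None),
    vision_config=SimpleNamespace(torch_dtype=None),
)

torch_dtype = getattr(config, "torch_dtype")
if not getattr(config.text_config, "torch_dtype"):
    config.text_config.torch_dtype = torch_dtype
if not getattr(config.vision_config, "torch_dtype"):
    config.vision_config.torch_dtype = torch_dtype

print(config.text_config.torch_dtype, config.vision_config.torch_dtype)
# torch.bfloat16 torch.bfloat16
```
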
sglang/srt/models/mimo_mtp.py (new file)
@@ -0,0 +1,220 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.mimo import MiMoForCausalLM
+from sglang.srt.models.qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2MLP,
+    Qwen2Model,
+)
+from sglang.srt.utils import add_prefix
+
+
+class MiMoMultiTokenPredictorLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        prefix: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_proj = nn.Linear(
+            config.hidden_size * 2, config.hidden_size, bias=False
+        )
+        self.mtp_block = Qwen2DecoderLayer(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        # masking inputs at position 0, as not needed by MTP
+        hidden_states[positions == 0] = 0
+
+        hidden_states = self.input_proj(
+            torch.cat(
+                (
+                    self.hidden_layernorm(forward_batch.spec_info.hidden_states),
+                    self.token_layernorm(hidden_states),
+                ),
+                dim=-1,
+            )
+        )
+
+        hidden_states, residual = self.mtp_block(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            residual=None,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class MiMoMTP(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+
+        self.model = MiMoMultiTokenPredictorLayer(
+            config,
+            prefix,
+            quant_config,
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+            name = self.map_model_name_to_mtp_param_name(name)
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mtp_block" not in name:
+                    break
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if "mtp_block" not in name and (
+                    "embed_tokens" not in name
+                    and "lm_head" not in name
+                    and "token_layernorm" not in name
+                    and "hidden_layernorm" not in name
+                    and "input_proj" not in name
+                    and "final_layernorm" not in name
+                ):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def map_model_name_to_mtp_param_name(self, name: str) -> str:
+        import re
+
+        name_without_prefix = [
+            "token_layernorm",
+            "hidden_layernorm",
+            "input_proj",
+            "final_layernorm",
+        ]
+        pattern = r"model.mtp_layers.(\d+)."
+        group = re.match(pattern, name)
+        if group is not None:
+            for sub_name in name_without_prefix:
+                if sub_name in name:
+                    name = name.replace(group.group(), "model.")
+                    return name
+            name = name.replace(group.group(), "model.mtp_block.")
+        return name
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+
+EntryClass = MiMoMTP
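
In the new `mimo_mtp.py`, `map_model_name_to_mtp_param_name` rewrites checkpoint names: for weights under `model.mtp_layers.N.`, the layer-level norms and input projection are lifted to the top-level `model.`, while everything else is routed into `model.mtp_block.`. The standalone check below replicates that logic; the example weight names are assumed for illustration, not taken from a real MiMo checkpoint.

```python
import re

def map_name(name: str) -> str:
    # Same logic as MiMoMTP.map_model_name_to_mtp_param_name above.
    name_without_prefix = [
        "token_layernorm", "hidden_layernorm", "input_proj", "final_layernorm",
    ]
    pattern = r"model.mtp_layers.(\d+)."
    group = re.match(pattern, name)
    if group is not None:
        for sub_name in name_without_prefix:
            if sub_name in name:
                return name.replace(group.group(), "model.")
        name = name.replace(group.group(), "model.mtp_block.")
    return name

print(map_name("model.mtp_layers.0.input_proj.weight"))
# model.input_proj.weight
print(map_name("model.mtp_layers.0.self_attn.q_proj.weight"))
# model.mtp_block.self_attn.q_proj.weight
```
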
sglang/srt/models/minicpmo.py
@@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel):
         slice_start_id: int = mm_input.slice_start_id
         slice_end_id: int = mm_input.slice_end_id

-        media_token_pairs = [
+        data_token_pairs = [
             (im_start_id, im_end_id),
             (slice_start_id, slice_end_id),
             (mm_input.audio_start_id, mm_input.audio_end_id),
         ]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        data_start_token_ids = [im_start_id, mm_input.audio_start_id]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(
+            data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids
+        )

         return pattern.pad_input_tokens(input_ids, mm_input)

@@ -1823,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel):
         **kwargs: Any,
     ) -> torch.Tensor:

-        mm_input = forward_batch.merge_mm_inputs()
-        placeholder_token_ids = (
-            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
-            if forward_batch.contains_mm_inputs()
-            else []
-        )
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_tokens={
-                Modality.IMAGE: placeholder_token_ids,
-                Modality.AUDIO: placeholder_token_ids,
-            },
             positions=positions,
         )
         return hidden_states
sglang/srt/models/mistral.py
@@ -13,6 +13,12 @@
 # ==============================================================================
 """Inference-only Mistral model."""

+from typing import List, Union
+
+import torch
+from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
+
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -20,4 +26,68 @@ class MistralForCausalLM(LlamaForCausalLM):
     pass


-EntryClass = MistralForCausalLM
+class Mistral3ForConditionalGeneration:
+    MULTIMODAL_PROJECTOR_TYPE = Mistral3MultiModalProjector
+
+    def __init__(self, **kwargs):
+        # lazy load inner class
+        # to bypass circular import
+        from sglang.srt.models.llava import LlavaForConditionalGeneration
+
+        # override config: mistral's projector adds patchmerger that doesn't require padding
+        kwargs["config"].vision_config.pad_image_border = False
+
+        self.inner = LlavaForConditionalGeneration(**kwargs)
+        self.inner.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(
+            kwargs["config"]
+        )
+        self.inner.get_image_feature = self.get_image_feature
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(
+                    selected_image_feature.squeeze(0), image_sizes
+                )
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def __getattr__(self, name):
+        return getattr(self.inner, name)
+
+    def __hasattr__(self, name):
+        return hasattr(self.inner, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.inner(*args, **kwargs)
+
+
+EntryClass = [MistralForCausalLM, Mistral3ForConditionalGeneration]
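
`Mistral3ForConditionalGeneration` above is not an `nn.Module`; it wraps a `LlavaForConditionalGeneration`, swaps in Mistral's projector and its own `get_image_feature`, and forwards everything else to the inner model. A minimal sketch of that delegation pattern with toy classes (not sglang's):

```python
class Inner:
    def __init__(self):
        self.hidden_size = 16

    def __call__(self, x):
        return x * 2

    def get_image_feature(self, items):
        return "inner features"


class Wrapper:
    def __init__(self):
        self.inner = Inner()
        # Re-point the inner model's hook at the wrapper's override,
        # mirroring `self.inner.get_image_feature = self.get_image_feature`.
        self.inner.get_image_feature = self.get_image_feature

    def get_image_feature(self, items):
        return "wrapper features"

    def __getattr__(self, name):
        # Only called when normal lookup fails; delegate to the inner model.
        return getattr(self.inner, name)

    def __call__(self, *args, **kwargs):
        return self.inner(*args, **kwargs)


w = Wrapper()
print(w.hidden_size)            # 16, delegated to Inner
print(w.get_image_feature([]))  # "wrapper features", the override wins
print(w(3))                     # 6, forwarded through __call__
```
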