sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0

sglang/srt/models/interns1.py
@@ -4,8 +4,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -20,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.internvl import InternVisionModel
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
@@ -34,7 +36,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self._update_hf_config()
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = (
             getattr(config, "force_image_size", None) or config.vision_config.image_size
         )
@@ -69,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.text_config, quant_config=quant_config
             )
+        elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.text_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.text_config.architectures[0]} is not implemented."
@@ -86,21 +92,6 @@ class InternS1ForConditionalGeneration(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
-    def _update_hf_config(self):
-        """update hf config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -183,34 +174,6 @@ class InternS1ForConditionalGeneration(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            padded_weight = loaded_weight.new_zeros(dummy_shape)
-            loaded_weight = torch.cat(
-                [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
-            ).flatten(0, 1)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def _mapping_interns1_name(self, name):
         names_map = {
             "lm_head.weight": "language_model.lm_head.weight",
@@ -254,7 +217,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         ]
         expert_params_mapping = []
         if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
-            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
@@ -269,7 +232,9 @@ class InternS1ForConditionalGeneration(nn.Module):
                 continue
             name = self._mapping_interns1_name(name)
             if "vision_model" in name:
-                loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+                loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                    self.config, name, loaded_weight
+                )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:

sglang/srt/models/internvl.py
@@ -10,9 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
-from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -412,7 +412,7 @@ class InternVLChatModel(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self._update_vision_config()
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -462,21 +462,6 @@ class InternVLChatModel(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
-    def _update_vision_config(self):
-        """update vision config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -559,36 +544,6 @@ class InternVLChatModel(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if "attn.qkv_proj" in name:
-            wq, wk, wv = loaded_weight.chunk(3, dim=0)
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            pad_func = lambda x: torch.cat(
-                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
-            ).flatten(0, 1)
-            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
-            loaded_weight = torch.cat([wq, wk, wv], dim=0)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = []
         if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
@@ -616,7 +571,7 @@ class InternVLChatModel(nn.Module):
                 ("gate_up_proj", "up_proj", 1),
             ]
 
-            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
@@ -699,8 +654,8 @@ class InternVLChatModel(nn.Module):
                         param, "weight_loader", default_weight_loader
                     )
                     if "vision_model" in name:
-                        loaded_weight = self._pad_vit_attn_dummy_heads(
-                            name, loaded_weight
+                        loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                            self.config, name, loaded_weight
                         )
                     weight_loader(param, loaded_weight)
 
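
Note: the two helpers deleted above from interns1.py and internvl.py are consolidated into the new sglang/srt/layers/attention/vision_utils.py (+65 -0 in the file list), which is not itself shown in this excerpt. The sketch below reconstructs the two entry points from the deleted methods; only the call signatures update_vit_attn_dummy_heads_config(config) and pad_vit_attn_dummy_heads(config, name, loaded_weight) appear in the diff, so everything else here is an assumption about the shipped module.

# Hedged sketch of the consolidated vision_utils helpers, reconstructed from the
# deleted InternS1/InternVL methods; not a copy of the shipped module. For brevity
# it only covers fused "attn.qkv_proj" weights (the InternVL layout); the split
# q_proj/k_proj/v_proj case handled by the deleted InternS1 method is omitted.
import torch

from sglang.srt.distributed import parallel_state


def update_vit_attn_dummy_heads_config(config):
    """Record head_dim and the number of dummy heads needed so the ViT head
    count divides evenly across the tensor-parallel world size."""
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    num_heads = config.vision_config.num_attention_heads
    head_dim = config.vision_config.hidden_size // num_heads
    num_dummy_heads = 0
    if num_heads % world_size != 0:
        num_dummy_heads = (
            (num_heads + world_size) // world_size
        ) * world_size - num_heads
    config.vision_config.head_dim = head_dim
    config.vision_config.num_dummy_heads = num_dummy_heads


def pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor):
    """Zero-pad ViT attention weights for the dummy heads added above."""
    num_dummy_heads = config.vision_config.num_dummy_heads
    if num_dummy_heads == 0:
        return loaded_weight
    head_dim = config.vision_config.head_dim
    if "attn.qkv_proj" in name:
        # q/k/v are stored fused along dim 0; pad each chunk separately.
        wq, wk, wv = loaded_weight.chunk(3, dim=0)
        if name.endswith(".weight"):
            dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
        elif name.endswith(".bias"):
            dummy_shape = [num_dummy_heads, head_dim]
        else:
            raise RuntimeError(f"Unsupported weight with name={name}")
        pad = lambda x: torch.cat(
            [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
        ).flatten(0, 1)
        loaded_weight = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)
    if "attn.proj.weight" in name:
        # Output projection grows along the input dimension.
        padded = loaded_weight.new_zeros(
            loaded_weight.shape[0], head_dim * num_dummy_heads
        )
        loaded_weight = torch.cat([loaded_weight, padded], dim=-1)
    if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
        padded = loaded_weight.new_zeros(head_dim * num_dummy_heads)
        loaded_weight = torch.cat([loaded_weight, padded], dim=0)
    return loaded_weight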

sglang/srt/models/llama4.py
@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 

sglang/srt/models/minicpm3.py
@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda

sglang/srt/models/mixtral.py
@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
         )
 

sglang/srt/models/nemotron_nas.py
@@ -0,0 +1,435 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_nas.py
+
+"""Inference-only deci model compatible with HuggingFace weights."""
+from typing import Iterable, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.llama import LlamaAttention, LlamaMLP
+from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import logger
+
+
+def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+    # DeciLM-specific code
+    intermediate_size = int(2 * ffn_mult * n_embd / 3)
+    return _find_multiple(intermediate_size, 256)
+
+
+def _find_multiple(n: int, k: int) -> int:
+    # DeciLM-specific code
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+class DeciLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        block_config = config.block_configs[layer_idx]
+        self._is_no_op_attention = block_config.attention.no_op
+        self._is_no_op_ffn = block_config.ffn.no_op
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        # support internlm/internlm3-8b with qkv_bias
+        if hasattr(config, "qkv_bias"):
+            attention_bias = config.qkv_bias
+
+        if not self._is_no_op_attention:
+            num_kv_heads = (
+                config.num_attention_heads // block_config.attention.n_heads_in_group
+            )
+            self.self_attn = LlamaAttention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                num_kv_heads=num_kv_heads,
+                layer_id=layer_idx,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                rope_is_neox_style=rope_is_neox_style,
+                max_position_embeddings=max_position_embeddings,
+                quant_config=quant_config,
+                prefix=add_prefix("self_attn", prefix),
+                bias=attention_bias,
+            )
+            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if not self._is_no_op_ffn:
+            ffn_mult = block_config.ffn.ffn_mult
+            intermediate_size = _ffn_mult_to_intermediate_size(
+                ffn_mult, config.hidden_size
+            )
+            self.mlp = LlamaMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+            self.post_attention_layernorm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+
+        if self._is_no_op_attention:
+            pass
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        # Fully Connected
+        if not self._is_no_op_ffn:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
+            hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class DeciModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        self.quant_config = quant_config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        vocab_size = config.vocab_size + lora_vocab
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        def get_layer(idx: int, prefix: str):
+            return layer_type(
+                config,
+                layer_idx=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            get_layer,
+            pp_rank=get_pp_group().rank_in_group,
+            pp_size=get_pp_group().world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        kv_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            if not layer._is_no_op_attention:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+                kv_cache_index += 1
+            else:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class DeciLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm",
+    }
+
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.lora_config = lora_config
+
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return DeciModel(config=config, quant_config=quant_config, prefix=prefix)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            inputs_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if get_pp_group().is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if self.model.quant_config is not None and (
+                scale_name := self.model.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                continue
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [DeciLMForCausalLM]
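
In the new nemotron_nas.py, each DeciLM block derives its FFN width from a per-block ffn_mult rather than a fixed intermediate_size. The snippet below is a standalone sanity check of the arithmetic in _ffn_mult_to_intermediate_size and _find_multiple above; the configuration values are made up for illustration and do not come from any shipped checkpoint.

# Standalone copy of the two DeciLM sizing helpers, exercised with hypothetical inputs.
def _find_multiple(n: int, k: int) -> int:
    # Round n up to the next multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)


def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    # Scale the nominal width by 2/3, then round up to a multiple of 256.
    intermediate_size = int(2 * ffn_mult * n_embd / 3)
    return _find_multiple(intermediate_size, 256)


print(_ffn_mult_to_intermediate_size(5.25, 8192))  # 28672 (already a multiple of 256)
print(_ffn_mult_to_intermediate_size(2.5, 4096))   # 6912  (6826 rounded up to 27 * 256)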