sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py CHANGED
@@ -190,6 +190,7 @@ class Gemma2DecoderLayer(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.layer_id = layer_id
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
             layer_id=layer_id,
@@ -380,6 +381,57 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize
+            normalizer = torch.tensor(
+                self.model.config.hidden_size**0.5, dtype=torch.float16
+            )
+            forward_batch.hidden_states *= normalizer
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "qkv_proj"]:
sglang/srt/models/gemma3_causal.py CHANGED
@@ -647,6 +647,69 @@ class Gemma3ForCausalLM(PreTrainedModel):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+
+            if positions.dim() == 1:
+                positions = einops.rearrange(positions, "s -> 1 s")
+            position_embeddings_global = self.model.rotary_emb(hidden_states, positions)
+            position_embeddings_local = self.model.rotary_emb_local(
+                hidden_states, positions
+            )
+
+            forward_batch.hidden_states = hidden_states
+            forward_batch.model_specific_states = {
+                "positions": positions,
+                "position_embeddings_global": position_embeddings_global,
+                "position_embeddings_local": position_embeddings_local,
+            }
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            layer_output = layer(
+                positions=forward_batch.model_specific_states["positions"],
+                position_embeddings_global=forward_batch.model_specific_states[
+                    "position_embeddings_global"
+                ],
+                position_embeddings_local=forward_batch.model_specific_states[
+                    "position_embeddings_local"
+                ],
+                hidden_states=forward_batch.hidden_states,
+                forward_batch=forward_batch,
+            )
+            forward_batch.hidden_states = layer_output[0]
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.model.norm(forward_batch.hidden_states)
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
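The `forward_split_prefill` methods added above (to Gemma2 and Gemma3 here, and, per the file list, to gemma.py, llama.py, qwen.py, qwen2.py, and qwen3.py as well) run prefill over a half-open range of decoder layers, carrying intermediate activations on the `ForwardBatch` between calls: embedding happens only when `start == 0`, and logits are produced only when `end` reaches the last layer. A minimal driver sketch, with the helper name, chunk size, and prepared inputs assumed for illustration; the real call sites are in the scheduler and model-runner changes listed above, not shown in this diff:

# Hypothetical driver for the new forward_split_prefill API (sketch only).
def run_split_prefill(model, input_ids, positions, forward_batch, chunk_size=8):
    num_layers = model.model.config.num_hidden_layers
    result = None
    for start in range(0, num_layers, chunk_size):
        end = min(start + chunk_size, num_layers)
        # Embedding runs only on the first chunk; every call before the last
        # returns None, and the final chunk returns the logits output.
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )
    return result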
sglang/srt/models/gemma3_mm.py CHANGED
@@ -283,7 +283,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
sglang/srt/models/gemma3n_mm.py CHANGED
@@ -265,7 +265,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
@@ -316,9 +316,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`).
         """
         # Extract audio features and masks from items
-        all_input_features = flatten_nested_list(
-            [item.input_features for item in items]
-        )
+        all_input_features = flatten_nested_list([item.feature for item in items])
         all_input_features_mask = flatten_nested_list(
             [~item.input_features_mask for item in items]
         )  # Note(Xinyuan): reverse the mask according to the HF implementation
sglang/srt/models/granitemoe.py ADDED
@@ -0,0 +1,385 @@
+"""Inference-only GraniteMoe model."""
+
+from typing import Iterable, Optional
+
+import torch
+from torch import nn
+from transformers import GraniteConfig
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models import mixtral
+from sglang.srt.utils import add_prefix
+
+
+class GraniteMoeMoE(nn.Module):
+    """A tensor-parallel MoE implementation for GraniteMoe that shards each
+    expert across all ranks.
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            hidden_size,
+            num_experts,
+            bias=False,
+            params_dtype=params_dtype,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True,
+        )
+
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=f"{prefix}.experts",
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        router_logits, _ = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
+        return final_hidden_states.view(orig_shape)
+
+
+class GraniteMoeAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        quant_config: Optional[QuantizationConfig] = None,
+        attention_multiplier: Optional[float] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = (
+            attention_multiplier
+            if attention_multiplier is not None
+            else self.head_dim**-1
+        )
+        self.rope_theta = rope_theta
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class GraniteMoeDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: GraniteConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = GraniteMoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            attention_multiplier=config.attention_multiplier,
+        )
+        self.block_sparse_moe = GraniteMoeMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.residual_multiplier = config.residual_multiplier
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = residual + hidden_states * self.residual_multiplier
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        hidden_states = residual + hidden_states * self.residual_multiplier
+
+        return hidden_states
+
+
+class GraniteMoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GraniteConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.embedding_multiplier = config.embedding_multiplier
+
+        self.layers = nn.ModuleList(
+            [
+                GraniteMoeDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
+        hidden_states *= self.embedding_multiplier
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class GraniteMoeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: GraniteConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = GraniteMoeModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        # Granite logit scaling factors are applied via division, but
+        # LogitsProcessor expects a multiplicative factor.
+        if hasattr(config, "logits_scaling"):
+            logit_scale = 1.0 / config.logits_scaling
+        else:
+            logit_scale = None
+        self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            logits_processor_output: LogitsProcessorOutput = self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+            return logits_processor_output
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        new_weights = {}
+        for n, p in weights:
+            if n.endswith(".block_sparse_moe.input_linear.weight"):
+                for e in range(p.size(0)):
+                    w1_name = n.replace(
+                        ".block_sparse_moe.input_linear.weight",
+                        f".block_sparse_moe.experts.{e}.w1.weight",
+                    )
+                    w3_name = n.replace(
+                        ".block_sparse_moe.input_linear.weight",
+                        f".block_sparse_moe.experts.{e}.w3.weight",
+                    )
+                    w1_param, w3_param = p[e].chunk(2, dim=0)
+                    assert w1_name not in new_weights
+                    assert w3_name not in new_weights
+                    new_weights[w1_name] = w1_param
+                    new_weights[w3_name] = w3_param
+            elif n.endswith(".block_sparse_moe.output_linear.weight"):
+                for e in range(p.size(0)):
+                    w2_name = n.replace(
+                        ".block_sparse_moe.output_linear.weight",
+                        f".block_sparse_moe.experts.{e}.w2.weight",
+                    )
+                    w2_param = p[e]
+                    assert w2_name not in new_weights
+                    new_weights[w2_name] = w2_param
+            elif n.endswith(".block_sparse_moe.router.layer.weight"):
+                gate_name = n.replace(
+                    ".block_sparse_moe.router.layer.weight",
+                    ".block_sparse_moe.gate.weight",
+                )
+                assert gate_name not in new_weights
+                new_weights[gate_name] = p
+            else:
+                new_weights[n] = p
+        mixtral.MixtralForCausalLM.load_weights(self, new_weights.items())
+
+
+EntryClass = [GraniteMoeForCausalLM]
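The `load_weights` method above unpacks GraniteMoe's fused expert checkpoints into the per-expert names expected by `FusedMoE` (Mixtral-style `w1`/`w2`/`w3`) before delegating to `mixtral.MixtralForCausalLM.load_weights`. A standalone sketch of the key remapping for one hypothetical tensor; the layer index, expert count, and tensor shapes here are made up for illustration:

# Illustrative only: mimic the input_linear -> w1/w3 split done in load_weights.
import torch

name = "model.layers.0.block_sparse_moe.input_linear.weight"
fused = torch.randn(2, 8, 4)  # assumed (num_experts, 2 * intermediate, hidden)

for e in range(fused.size(0)):
    w1, w3 = fused[e].chunk(2, dim=0)  # first half -> w1, second half -> w3
    w1_name = name.replace(
        ".block_sparse_moe.input_linear.weight",
        f".block_sparse_moe.experts.{e}.w1.weight",
    )
    print(w1_name, tuple(w1.shape))  # e.g. ...experts.0.w1.weight (4, 4)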
sglang/srt/models/grok.py CHANGED
@@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.router import fused_moe_router_shim
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -108,6 +109,12 @@ class Grok1MoE(nn.Module):
             fused_moe_router_shim, self.router_logit_softcapping
         )

+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+            custom_routing_function=custom_routing_function,
+        )
+
         kwargs = {}
         if global_server_args_dict["enable_ep_moe"]:
             MoEImpl = EPMoE
@@ -124,17 +131,16 @@ class Grok1MoE(nn.Module):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
-            custom_routing_function=custom_routing_function,
             activation="gelu",
             **kwargs,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # need to assert self.gate.quant_method is unquantized
-        return self.experts(hidden_states, self.gate.weight)
+        topk_output = self.topk(hidden_states, self.gate.weight)
+        return self.experts(hidden_states, topk_output)


 class Grok1Attention(nn.Module):
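The Grok1MoE change mirrors the wiring used by the new GraniteMoeMoE above: top-k routing now lives in a standalone `TopK` module (see `sglang/srt/layers/moe/topk.py`, +187 -12 in this release), and `FusedMoE` no longer takes `renormalize` or `custom_routing_function`; it consumes the `TopK` output instead. A condensed sketch of the resulting call pattern, with module construction elided and the function written as a free helper purely for illustration:

# Sketch of the post-refactor routing flow (mirrors GraniteMoeMoE.forward and
# the new Grok1MoE.forward); gate/topk/experts are the modules built in __init__.
def moe_forward(gate, topk, experts, hidden_states):
    router_logits, _ = gate(hidden_states)            # per-token router scores
    topk_output = topk(hidden_states, router_logits)  # selected experts + weights
    return experts(hidden_states, topk_output)        # fused expert compute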