sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py
CHANGED
@@ -190,6 +190,7 @@ class Gemma2DecoderLayer(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.layer_id = layer_id
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
             layer_id=layer_id,
@@ -380,6 +381,57 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize
+            normalizer = torch.tensor(
+                self.model.config.hidden_size**0.5, dtype=torch.float16
+            )
+            forward_batch.hidden_states *= normalizer
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "qkv_proj"]:
sglang/srt/models/gemma3_causal.py
CHANGED
@@ -647,6 +647,69 @@ class Gemma3ForCausalLM(PreTrainedModel):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+
+            if positions.dim() == 1:
+                positions = einops.rearrange(positions, "s -> 1 s")
+            position_embeddings_global = self.model.rotary_emb(hidden_states, positions)
+            position_embeddings_local = self.model.rotary_emb_local(
+                hidden_states, positions
+            )
+
+            forward_batch.hidden_states = hidden_states
+            forward_batch.model_specific_states = {
+                "positions": positions,
+                "position_embeddings_global": position_embeddings_global,
+                "position_embeddings_local": position_embeddings_local,
+            }
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            layer_output = layer(
+                positions=forward_batch.model_specific_states["positions"],
+                position_embeddings_global=forward_batch.model_specific_states[
+                    "position_embeddings_global"
+                ],
+                position_embeddings_local=forward_batch.model_specific_states[
+                    "position_embeddings_local"
+                ],
+                hidden_states=forward_batch.hidden_states,
+                forward_batch=forward_batch,
+            )
+            forward_batch.hidden_states = layer_output[0]
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.model.norm(forward_batch.hidden_states)
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
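Both Gemma2ForCausalLM and Gemma3ForCausalLM gain a forward_split_prefill method that runs only the decoder layers in the half-open interval [start, end), stashing intermediate state on the ForwardBatch, and produces logits only on the chunk that reaches the final layer. A minimal driver sketch, assuming the caller already has a prepared model, ForwardBatch, input_ids, and positions; the chunk size and loop here are illustrative, not the scheduler's actual code:

def run_split_prefill(model, input_ids, positions, forward_batch, chunk_size=8):
    # Illustrative driver: walk the layers in [start, end) chunks.
    # Intermediate chunks return None; only the chunk that reaches the last
    # layer runs the final norm + logits processor and returns a result.
    num_layers = model.model.config.num_hidden_layers
    result = None
    for start in range(0, num_layers, chunk_size):
        end = min(start + chunk_size, num_layers)
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )
    return result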
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -283,7 +283,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
sglang/srt/models/gemma3n_mm.py
CHANGED
@@ -265,7 +265,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
@@ -316,9 +316,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`).
         """
         # Extract audio features and masks from items
-        all_input_features = flatten_nested_list(
-            [item.input_features for item in items]
-        )
+        all_input_features = flatten_nested_list([item.feature for item in items])
         all_input_features_mask = flatten_nested_list(
             [~item.input_features_mask for item in items]
         )  # Note(Xinyuan): reverse the mask according to the HF implementation
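Both multimodal hunks are part of the same rename: per-modality fields on multimodal data items (pixel_values, input_features) are replaced by a single generic feature field, and the nested per-request lists are flattened before being fed to the vision or audio tower. For orientation, a helper with the behavior flatten_nested_list is used for here can be sketched as follows; this is an assumed re-implementation for illustration, not sglang's actual utility:

def flatten_nested_list(nested):
    # Recursively flatten arbitrarily nested lists into one flat list,
    # e.g. [[a, b], [c]] -> [a, b, c]. Non-list elements are kept as-is.
    flat = []
    for x in nested:
        if isinstance(x, list):
            flat.extend(flatten_nested_list(x))
        else:
            flat.append(x)
    return flat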
sglang/srt/models/granitemoe.py
ADDED
@@ -0,0 +1,385 @@
+"""Inference-only GraniteMoe model."""
+
+from typing import Iterable, Optional
+
+import torch
+from torch import nn
+from transformers import GraniteConfig
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models import mixtral
+from sglang.srt.utils import add_prefix
+
+
+class GraniteMoeMoE(nn.Module):
+    """A tensor-parallel MoE implementation for GraniteMoe that shards each
+    expert across all ranks.
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            hidden_size,
+            num_experts,
+            bias=False,
+            params_dtype=params_dtype,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True,
+        )
+
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=f"{prefix}.experts",
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        router_logits, _ = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
+        return final_hidden_states.view(orig_shape)
+
+
+class GraniteMoeAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        quant_config: Optional[QuantizationConfig] = None,
+        attention_multiplier: Optional[float] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = (
+            attention_multiplier
+            if attention_multiplier is not None
+            else self.head_dim**-1
+        )
+        self.rope_theta = rope_theta
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class GraniteMoeDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: GraniteConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = GraniteMoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            attention_multiplier=config.attention_multiplier,
+        )
+        self.block_sparse_moe = GraniteMoeMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.residual_multiplier = config.residual_multiplier
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = residual + hidden_states * self.residual_multiplier
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        hidden_states = residual + hidden_states * self.residual_multiplier
+
+        return hidden_states
+
+
+class GraniteMoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GraniteConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.embedding_multiplier = config.embedding_multiplier
+
+        self.layers = nn.ModuleList(
+            [
+                GraniteMoeDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
+        hidden_states *= self.embedding_multiplier
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class GraniteMoeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: GraniteConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = GraniteMoeModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        # Granite logit scaling factors are applied via division, but
+        # LogitsProcessor expects a multiplicative factor.
+        if hasattr(config, "logits_scaling"):
+            logit_scale = 1.0 / config.logits_scaling
+        else:
+            logit_scale = None
+        self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            logits_processor_output: LogitsProcessorOutput = self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+            return logits_processor_output
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        new_weights = {}
+        for n, p in weights:
+            if n.endswith(".block_sparse_moe.input_linear.weight"):
+                for e in range(p.size(0)):
+                    w1_name = n.replace(
+                        ".block_sparse_moe.input_linear.weight",
+                        f".block_sparse_moe.experts.{e}.w1.weight",
+                    )
+                    w3_name = n.replace(
+                        ".block_sparse_moe.input_linear.weight",
+                        f".block_sparse_moe.experts.{e}.w3.weight",
+                    )
+                    w1_param, w3_param = p[e].chunk(2, dim=0)
+                    assert w1_name not in new_weights
+                    assert w3_name not in new_weights
+                    new_weights[w1_name] = w1_param
+                    new_weights[w3_name] = w3_param
+            elif n.endswith(".block_sparse_moe.output_linear.weight"):
+                for e in range(p.size(0)):
+                    w2_name = n.replace(
+                        ".block_sparse_moe.output_linear.weight",
+                        f".block_sparse_moe.experts.{e}.w2.weight",
+                    )
+                    w2_param = p[e]
+                    assert w2_name not in new_weights
+                    new_weights[w2_name] = w2_param
+            elif n.endswith(".block_sparse_moe.router.layer.weight"):
+                gate_name = n.replace(
+                    ".block_sparse_moe.router.layer.weight",
+                    ".block_sparse_moe.gate.weight",
+                )
+                assert gate_name not in new_weights
+                new_weights[gate_name] = p
+            else:
+                new_weights[n] = p
+        mixtral.MixtralForCausalLM.load_weights(self, new_weights.items())
+
+
+EntryClass = [GraniteMoeForCausalLM]
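The load_weights method above adapts GraniteMoe checkpoints to the Mixtral-style fused-MoE layout: each expert's fused input_linear weight is split in half into w1 (gate) and w3 (up) projections, output_linear maps to w2, and the router becomes the gate, after which the remapped weights are handed to MixtralForCausalLM.load_weights. A small standalone illustration of the split, with made-up toy shapes:

import torch

num_experts, intermediate_size, hidden_size = 4, 16, 8  # toy shapes
# Checkpoint-style fused weight: (num_experts, 2 * intermediate, hidden).
input_linear = torch.randn(num_experts, 2 * intermediate_size, hidden_size)

for e in range(num_experts):
    # Same split as load_weights: first half -> w1 (gate), second half -> w3 (up).
    w1, w3 = input_linear[e].chunk(2, dim=0)
    assert w1.shape == w3.shape == (intermediate_size, hidden_size)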
sglang/srt/models/grok.py
CHANGED
@@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.router import fused_moe_router_shim
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -108,6 +109,12 @@ class Grok1MoE(nn.Module):
             fused_moe_router_shim, self.router_logit_softcapping
         )

+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+            custom_routing_function=custom_routing_function,
+        )
+
         kwargs = {}
         if global_server_args_dict["enable_ep_moe"]:
             MoEImpl = EPMoE
@@ -124,17 +131,16 @@ class Grok1MoE(nn.Module):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
-            custom_routing_function=custom_routing_function,
             activation="gelu",
             **kwargs,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # need to assert self.gate.quant_method is unquantized
-
+        topk_output = self.topk(hidden_states, self.gate.weight)
+        return self.experts(hidden_states, topk_output)


 class Grok1Attention(nn.Module):
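This is the same refactor visible in the GraniteMoe model above: top-k routing moves out of FusedMoE (the renormalize and custom_routing_function constructor arguments are removed) and into a dedicated TopK module whose output is passed to the experts. The core operation a TopK module encapsulates, picking the k largest router weights per token and optionally renormalizing them, can be illustrated standalone; this is a toy re-implementation, not sglang's kernel:

import torch

def topk_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    # Toy top-k routing: softmax over experts, keep the k best per token.
    probs = torch.softmax(router_logits, dim=-1)
    weights, expert_ids = probs.topk(top_k, dim=-1)
    if renormalize:
        # Grok passes renormalize=False; GraniteMoe uses renormalize=True.
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, expert_ids

logits = torch.randn(3, 8)  # 3 tokens, 8 experts
w, ids = topk_routing(logits, top_k=2, renormalize=True)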
|