sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -38,6 +38,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
     DPPaddingMode,
     get_attention_dp_rank,
@@ -179,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None
 
@@ -188,6 +192,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None
 
     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
@@ -246,7 +251,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None
 
     # For LoRA
-
+    lora_ids: Optional[List[str]] = None
 
     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -319,13 +324,14 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -418,16 +424,12 @@ class ForwardBatch:
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
-
-
-
-
-
-
-            else:
-                positions, ret.extend_start_loc = compute_position_torch(
-                    ret.extend_prefix_lens, ret.extend_seq_lens
-                )
+            positions, ret.extend_start_loc = compute_position(
+                model_runner.server_args.attention_backend,
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
+            )
             if ret.positions is None:
                 ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -630,8 +632,10 @@ class ForwardBatch:
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
-            # when DP gather mode is all gather, we will use
-            #
+            # when DP gather mode is all gather, we will use
+            # all_gather_into_tensor to gather hidden states, where transferred
+            # tokens should be padded to the same length. We will also use
+            # reduce-scatter instead of all-reduce after MLP.
             max_num_tokens = max(global_num_tokens)
             global_num_tokens = [max_num_tokens] * sync_group_size
             buffer_len = max_num_tokens * sync_group_size
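
The expanded comment describes the gather pattern: in all-gather DP mode every rank must contribute the same number of padded tokens before `all_gather_into_tensor`, and the reverse direction can then use reduce-scatter instead of all-reduce. A minimal sketch of that padding-plus-gather step, assuming a 2-D hidden-state tensor and an already-initialized process group; the helper name `gather_hidden_padded` is illustrative, not an sglang API:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def gather_hidden_padded(hidden: torch.Tensor, max_num_tokens: int, group) -> torch.Tensor:
    # all_gather_into_tensor requires every rank to contribute the same shape,
    # so pad this rank's tokens up to max_num_tokens before gathering.
    pad_rows = max_num_tokens - hidden.shape[0]
    hidden = F.pad(hidden, (0, 0, 0, pad_rows))
    out = hidden.new_empty(dist.get_world_size(group) * max_num_tokens, hidden.shape[1])
    dist.all_gather_into_tensor(out, hidden, group=group)
    return out
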
@@ -644,12 +648,17 @@ class ForwardBatch:
             device=model_runner.device,
         )
 
-        bs = self.batch_size
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        if self.forward_mode.is_decode():
+            setattr(self, "raw_bs", self.batch_size)
+            self.batch_size = num_tokens
+
+        bs = self.batch_size
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -657,6 +666,9 @@ class ForwardBatch:
         seq_len_fill_value = (
             model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+        self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+            bs - self.seq_lens.shape[0]
+        )
         self.seq_lens = self._pad_tensor_to_size(
             self.seq_lens, bs, value=seq_len_fill_value
         )
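
A small sketch of the bookkeeping this hunk adds: when `seq_lens` is padded up to the padded batch size `bs` with `seq_len_fill_value`, `seq_lens_sum` has to grow by the same total so the two stay consistent. The helper below is illustrative only; the names mirror the diff but it is not sglang code:

import torch

def pad_seq_lens(seq_lens: torch.Tensor, seq_lens_sum: int, bs: int, fill_value: int):
    # Each padded slot contributes fill_value to the total, so account for it
    # when extending the tensor itself.
    num_pad = bs - seq_lens.shape[0]
    seq_lens_sum += fill_value * num_pad
    padding = torch.full((num_pad,), fill_value, dtype=seq_lens.dtype, device=seq_lens.device)
    return torch.cat([seq_lens, padding]), seq_lens_sum

seq_lens, seq_lens_sum = torch.tensor([5, 3]), 8
seq_lens, seq_lens_sum = pad_seq_lens(seq_lens, seq_lens_sum, bs=4, fill_value=1)
# seq_lens -> tensor([5, 3, 1, 1]), seq_lens_sum -> 10
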
@@ -700,7 +712,7 @@ class ForwardBatch:
 
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
-        bs = self.batch_size
+        bs = getattr(self, "raw_bs", self.batch_size)
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
@@ -839,7 +851,7 @@ class ForwardBatch:
 
 
 def enable_num_token_non_padded(server_args):
-    return
+    return get_moe_expert_parallel_world_size() > 1
 
 
 class PPProxyTensors:
@@ -872,6 +884,25 @@
         return f"PPProxyTensors(tensors={self.tensors})"
 
 
+def compute_position(
+    attn_backend: str,
+    extend_prefix_lens: torch.Tensor,
+    extend_seq_lens: torch.Tensor,
+    extend_seq_lens_sum: int,
+):
+    if support_triton(attn_backend):
+        positions, extend_start_loc = compute_position_triton(
+            extend_prefix_lens,
+            extend_seq_lens,
+            extend_seq_lens_sum,
+        )
+    else:
+        positions, extend_start_loc = compute_position_torch(
+            extend_prefix_lens, extend_seq_lens
+        )
+    return positions, extend_start_loc
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
|
@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
|
|
60
60
|
initialize_dp_attention,
|
61
61
|
)
|
62
62
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
63
|
+
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
|
63
64
|
from sglang.srt.layers.quantization import (
|
64
65
|
deep_gemm_wrapper,
|
65
66
|
monkey_patch_isinstance_for_vllm_base_layer,
|
@@ -217,6 +218,10 @@ class ModelRunner:
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
             }
+            | {
+                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
+                "deepep_mode": DeepEPMode(server_args.deepep_mode),
+            }
         )
 
         # CPU offload
@@ -1438,19 +1443,36 @@
             )
 
             return CutlassMLABackend(self)
-        elif
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
 
             return TRTLLMMLABackend(self)
-        elif
+        elif backend_str == "trtllm_mha":
+            if self.use_mla_backend:
+                raise ValueError(
+                    "trtllm_mha backend can only be used with non-MLA models."
+                )
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+            )
+
+            return TRTLLMHAAttnBackend(self)
+
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
 
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
 
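
The new `trtllm_mla`, `trtllm_mha`, and `dual_chunk_flash_attn` branches are reached through the existing `--attention-backend` server argument. A hedged usage sketch via the offline engine API, assuming the `Engine` constructor forwards server arguments as keyword arguments; the model path and backend choice are illustrative, and `trtllm_mha` is rejected for MLA models per the check above:

import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # any non-MLA model
        attention_backend="trtllm_mha",
    )
    print(llm.generate("Hello, my name is", {"max_new_tokens": 16}))
    llm.shutdown()
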
sglang/srt/model_loader/weight_utils.py

@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
                 return None
             return remapped_name
 
+    quark_scale_names = {
+        ".q_proj.output_scale": ".attn.q_scale",
+        ".k_proj.output_scale": ".attn.k_scale",
+        ".v_proj.output_scale": ".attn.v_scale",
+        "self_attn.prob_output_scale": ".attn.prob_scale",
+    }
+    for quark_scale_name, sglang_scale_name in quark_scale_names.items():
+        if name.endswith(quark_scale_name):
+            return name.replace(quark_scale_name, sglang_scale_name)
+
     # If there were no matches, return the untouched param name
     return name
 
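
The added suffix map rewrites Quark-exported KV-cache scale names into the attention-layer names sglang expects. A tiny self-contained illustration of the same loop (the parameter path is made up):

quark_scale_names = {
    ".q_proj.output_scale": ".attn.q_scale",
    ".k_proj.output_scale": ".attn.k_scale",
    ".v_proj.output_scale": ".attn.v_scale",
    "self_attn.prob_output_scale": ".attn.prob_scale",
}

name = "model.layers.0.self_attn.k_proj.output_scale"
for quark_scale_name, sglang_scale_name in quark_scale_names.items():
    if name.endswith(quark_scale_name):
        # -> model.layers.0.self_attn.attn.k_scale
        print(name.replace(quark_scale_name, sglang_scale_name))
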
sglang/srt/models/bailing_moe.py

@@ -0,0 +1,425 @@
+# Copyright 2023-2024 SGLang Team
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
+
+from collections.abc import Iterable
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class BailingAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_kv_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
+        self.q_size = self.num_heads * self.head_dim
+
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scale = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=(config.use_bias or config.use_qkv_bias),
+            quant_config=quant_config,
+            prefix=add_prefix("query_key_value", prefix),
+        )
+
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("dense", prefix),
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=config.rope_theta,
+            is_neox_style=True,
+            rope_scaling=config.rope_scaling,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class BailingMLP(nn.Module):
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: Optional[bool] = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            config.hidden_size,
+            [intermediate_size] * 2,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BailingMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.num_shared_experts = config.num_shared_experts
+        self.norm_expert_prob = config.norm_topk_prob
+        self.moe_intermediate_size = config.moe_intermediate_size
+
+        self.gate = ReplicatedLinear(
+            self.hidden_size, self.num_experts, bias=False, quant_config=None
+        )
+
+        self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.moe_intermediate_size,
+            reduce_results=False,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        if self.num_shared_experts > 0:
+            shared_intermediate_size = (
+                self.moe_intermediate_size * self.num_shared_experts
+            )
+            self.shared_experts = BailingMLP(
+                intermediate_size=shared_intermediate_size,
+                config=config,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+            )
+        else:
+            self.shared_experts = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_shape = hidden_states.shape
+        hidden_states_flat = hidden_states.view(-1, self.hidden_size)
+
+        shared_output = None
+        if self.shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states_flat)
+
+        router_logits, _ = self.gate(hidden_states_flat)
+        topk_output = self.topk(hidden_states_flat, router_logits)
+        final_hidden_states = self.experts(hidden_states_flat, topk_output)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
+class BailingMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention = BailingAttention(
+            config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.mlp = BailingMoE(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Pre-normalization and residual connection for the attention block
+        if residual is None:
+            residual = hidden_states
+            normed_hidden_states = self.input_layernorm(hidden_states)
+        else:
+            normed_hidden_states, residual = self.input_layernorm(
+                hidden_states, residual
+            )
+
+        attn_output = self.attention(
+            hidden_states=normed_hidden_states,
+            position_ids=position_ids,
+            forward_batch=forward_batch,
+        )
+
+        # Pre-normalization and residual connection for the MLP block
+        normed_hidden_states, residual = self.post_attention_layernorm(
+            attn_output, residual
+        )
+        mlp_output = self.mlp(normed_hidden_states)
+
+        return mlp_output, residual
+
+
+class BailingMoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_dim = config.hidden_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: BailingMoeBlock(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+
+        self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                hidden_states,
+                position_ids,
+                residual,
+                forward_batch,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class BailingMoeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.model = BailingMoeModel(config=config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            num_embeddings=config.vocab_size,
+            embedding_dim=config.hidden_size,
+            quant_config=quant_config,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        stacked_params_mapping = [
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+
+            if (
+                hasattr(self.config, "norm_head")
+                and self.config.norm_head
+                and "lm_head.weight" in name
+            ):
+                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
+
+            if "model.word_embeddings.weight" == name:
+                name = "model.embed_tokens.weight"
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name and "mlp.experts" not in name:
+                    full_param_name = name.replace(weight_name, param_name)
+                    param = params_dict[full_param_name]
+                    param.weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                for p_name, w_name, e_id, s_id in expert_params_mapping:
+                    if w_name in name and "mlp.experts" in name:
+                        full_param_name = name.replace(w_name, p_name)
+                        param = params_dict[full_param_name]
+                        param.weight_loader(
+                            param,
+                            loaded_weight,
+                            full_param_name,
+                            shard_id=s_id,
+                            expert_id=e_id,
+                        )
+                        break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = BailingMoeForCausalLM
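
The module-level `EntryClass` at the end is what sglang's model registry uses to associate this file with the `BailingMoeForCausalLM` architecture string from a checkpoint's config.json. A hedged sketch of that lookup; the import-based scan below is a simplification for illustration, not sglang's actual loader code:

from importlib import import_module

mod = import_module("sglang.srt.models.bailing_moe")
entry = getattr(mod, "EntryClass")
# A config.json with "architectures": ["BailingMoeForCausalLM"] would resolve
# to this class by matching the class name.
print(entry.__name__)  # BailingMoeForCausalLM
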