sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -38,6 +38,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
     DPPaddingMode,
     get_attention_dp_rank,
@@ -179,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None
 
@@ -188,6 +192,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None
 
     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
@@ -246,7 +251,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None
 
     # For LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_ids: Optional[List[str]] = None
 
     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -319,13 +324,14 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-            lora_paths=batch.lora_paths,
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -418,16 +424,12 @@ class ForwardBatch:
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
-            if support_triton(model_runner.server_args.attention_backend):
-                positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens,
-                    ret.extend_seq_lens,
-                    ret.extend_num_tokens,
-                )
-            else:
-                positions, ret.extend_start_loc = compute_position_torch(
-                    ret.extend_prefix_lens, ret.extend_seq_lens
-                )
+            positions, ret.extend_start_loc = compute_position(
+                model_runner.server_args.attention_backend,
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
+            )
             if ret.positions is None:
                 ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -630,8 +632,10 @@ class ForwardBatch:
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
-            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
-            # where transferred tokens should be padded to the same length.
+            # when DP gather mode is all gather, we will use
+            # all_gather_into_tensor to gather hidden states, where transferred
+            # tokens should be padded to the same length. We will also use
+            # reduce-scatter instead of all-reduce after MLP.
             max_num_tokens = max(global_num_tokens)
             global_num_tokens = [max_num_tokens] * sync_group_size
             buffer_len = max_num_tokens * sync_group_size
@@ -644,12 +648,17 @@ class ForwardBatch:
             device=model_runner.device,
         )
 
-        bs = self.batch_size
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        if self.forward_mode.is_decode():
+            setattr(self, "raw_bs", self.batch_size)
+            self.batch_size = num_tokens
+
+        bs = self.batch_size
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -657,6 +666,9 @@ class ForwardBatch:
         seq_len_fill_value = (
             model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+        self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+            bs - self.seq_lens.shape[0]
+        )
         self.seq_lens = self._pad_tensor_to_size(
             self.seq_lens, bs, value=seq_len_fill_value
         )
@@ -700,7 +712,7 @@ class ForwardBatch:
 
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
-        bs = self.batch_size
+        bs = getattr(self, "raw_bs", self.batch_size)
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
@@ -839,7 +851,7 @@
 
 
 def enable_num_token_non_padded(server_args):
-    return server_args.enable_ep_moe or server_args.enable_deepep_moe
+    return get_moe_expert_parallel_world_size() > 1
 
 
 class PPProxyTensors:
@@ -872,6 +884,25 @@ class PPProxyTensors:
         return f"PPProxyTensors(tensors={self.tensors})"
 
 
+def compute_position(
+    attn_backend: str,
+    extend_prefix_lens: torch.Tensor,
+    extend_seq_lens: torch.Tensor,
+    extend_seq_lens_sum: int,
+):
+    if support_triton(attn_backend):
+        positions, extend_start_loc = compute_position_triton(
+            extend_prefix_lens,
+            extend_seq_lens,
+            extend_seq_lens_sum,
+        )
+    else:
+        positions, extend_start_loc = compute_position_torch(
+            extend_prefix_lens, extend_seq_lens
+        )
+    return positions, extend_start_loc
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
sglang/srt/model_executor/model_runner.py

@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -217,6 +218,10 @@ class ModelRunner:
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
             }
+            | {
+                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
+                "deepep_mode": DeepEPMode(server_args.deepep_mode),
+            }
         )
 
         # CPU offload
@@ -1438,19 +1443,36 @@ class ModelRunner:
                 )
 
             return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
 
             return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "trtllm_mha":
+            if self.use_mla_backend:
+                raise ValueError(
+                    "trtllm_mha backend can only be used with non-MLA models."
+                )
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+            )
+
+            return TRTLLMHAAttnBackend(self)
+
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
 
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
 
sglang/srt/model_loader/weight_utils.py

@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
                 return None
             return remapped_name
 
+    quark_scale_names = {
+        ".q_proj.output_scale": ".attn.q_scale",
+        ".k_proj.output_scale": ".attn.k_scale",
+        ".v_proj.output_scale": ".attn.v_scale",
+        "self_attn.prob_output_scale": ".attn.prob_scale",
+    }
+    for quark_scale_name, sglang_scale_name in quark_scale_names.items():
+        if name.endswith(quark_scale_name):
+            return name.replace(quark_scale_name, sglang_scale_name)
+
     # If there were no matches, return the untouched param name
     return name
 
sglang/srt/models/bailing_moe.py

@@ -0,0 +1,425 @@
+# Copyright 2023-2024 SGLang Team
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
+
+from collections.abc import Iterable
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class BailingAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_kv_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
+        self.q_size = self.num_heads * self.head_dim
+
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scale = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=(config.use_bias or config.use_qkv_bias),
+            quant_config=quant_config,
+            prefix=add_prefix("query_key_value", prefix),
+        )
+
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("dense", prefix),
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=config.rope_theta,
+            is_neox_style=True,
+            rope_scaling=config.rope_scaling,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class BailingMLP(nn.Module):
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: Optional[bool] = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            config.hidden_size,
+            [intermediate_size] * 2,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BailingMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.num_shared_experts = config.num_shared_experts
+        self.norm_expert_prob = config.norm_topk_prob
+        self.moe_intermediate_size = config.moe_intermediate_size
+
+        self.gate = ReplicatedLinear(
+            self.hidden_size, self.num_experts, bias=False, quant_config=None
+        )
+
+        self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.moe_intermediate_size,
+            reduce_results=False,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        if self.num_shared_experts > 0:
+            shared_intermediate_size = (
+                self.moe_intermediate_size * self.num_shared_experts
+            )
+            self.shared_experts = BailingMLP(
+                intermediate_size=shared_intermediate_size,
+                config=config,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+            )
+        else:
+            self.shared_experts = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_shape = hidden_states.shape
+        hidden_states_flat = hidden_states.view(-1, self.hidden_size)
+
+        shared_output = None
+        if self.shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states_flat)
+
+        router_logits, _ = self.gate(hidden_states_flat)
+        topk_output = self.topk(hidden_states_flat, router_logits)
+        final_hidden_states = self.experts(hidden_states_flat, topk_output)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
+class BailingMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention = BailingAttention(
+            config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.mlp = BailingMoE(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Pre-normalization and residual connection for the attention block
+        if residual is None:
+            residual = hidden_states
+            normed_hidden_states = self.input_layernorm(hidden_states)
+        else:
+            normed_hidden_states, residual = self.input_layernorm(
+                hidden_states, residual
+            )
+
+        attn_output = self.attention(
+            hidden_states=normed_hidden_states,
+            position_ids=position_ids,
+            forward_batch=forward_batch,
+        )
+
+        # Pre-normalization and residual connection for the MLP block
+        normed_hidden_states, residual = self.post_attention_layernorm(
+            attn_output, residual
+        )
+        mlp_output = self.mlp(normed_hidden_states)
+
+        return mlp_output, residual
+
+
+class BailingMoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_dim = config.hidden_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: BailingMoeBlock(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+
+        self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                hidden_states,
+                position_ids,
+                residual,
+                forward_batch,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class BailingMoeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.model = BailingMoeModel(config=config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            num_embeddings=config.vocab_size,
+            embedding_dim=config.hidden_size,
+            quant_config=quant_config,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        stacked_params_mapping = [
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+
+            if (
+                hasattr(self.config, "norm_head")
+                and self.config.norm_head
+                and "lm_head.weight" in name
+            ):
+                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
+
+            if "model.word_embeddings.weight" == name:
+                name = "model.embed_tokens.weight"
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name and "mlp.experts" not in name:
+                    full_param_name = name.replace(weight_name, param_name)
+                    param = params_dict[full_param_name]
+                    param.weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                for p_name, w_name, e_id, s_id in expert_params_mapping:
+                    if w_name in name and "mlp.experts" in name:
+                        full_param_name = name.replace(w_name, p_name)
+                        param = params_dict[full_param_name]
+                        param.weight_loader(
+                            param,
+                            loaded_weight,
+                            full_param_name,
+                            shard_id=s_id,
+                            expert_id=e_id,
+                        )
+                        break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = BailingMoeForCausalLM