sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
-from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
+from sglang.srt.configs import (
+    FalconH1Config,
+    KimiLinearConfig,
+    NemotronHConfig,
+    Qwen3NextConfig,
+)
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -40,6 +45,9 @@ from sglang.srt.configs.model_config import (
 )
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.debug_utils.tensor_dump_forward_hook import (
+    register_forward_hook_for_model,
+)
 from sglang.srt.distributed import (
     get_pp_group,
     get_tp_group,
@@ -77,7 +85,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -349,7 +356,11 @@ class ModelRunner:
 
         if not self.is_draft_worker:
             set_global_expert_location_metadata(
-                compute_initial_expert_location_metadata(server_args, self.model_config)
+                compute_initial_expert_location_metadata(
+                    server_args=server_args,
+                    model_config=self.model_config,
+                    moe_ep_rank=self.moe_ep_rank,
+                )
             )
             if self.tp_rank == 0 and get_bool_env_var(
                 "SGLANG_LOG_EXPERT_LOCATION_METADATA"
@@ -730,7 +741,6 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
-        monkey_patch_isinstance_for_vllm_base_layer()
 
         with self.memory_saver_adapter.region(
             GPU_MEMORY_TYPE_WEIGHTS,
@@ -742,7 +752,6 @@ class ModelRunner:
             device_config=DeviceConfig(self.device, self.gpu_id),
         )
         monkey_patch_vllm_parallel_state(reverse=True)
-        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
         get_offloader().post_init()
 
@@ -790,6 +799,15 @@ class ModelRunner:
             f"avail mem={after_avail_memory:.2f} GB, "
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
+        if self.server_args.debug_tensor_dump_output_folder is not None:
+            register_forward_hook_for_model(
+                self.model,
+                self.server_args.debug_tensor_dump_output_folder,
+                self.server_args.debug_tensor_dump_layers,
+                self.tp_size,
+                self.tp_rank,
+                self.pp_rank,
+            )
 
         if self.server_args.elastic_ep_backend == "mooncake":
            # Mooncake does not support `monitored_barrier`
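For orientation, the new debug utility is built on PyTorch's standard forward-hook mechanism. The sketch below is a generic, hypothetical illustration of that mechanism (module, folder, and hook names are made up here), not the actual implementation in tensor_dump_forward_hook.py:

    import torch
    import torch.nn as nn

    def make_dump_hook(folder: str, name: str):
        # Save a module's output tensor to disk each time it runs.
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                torch.save(output.detach().cpu(), f"{folder}/{name}.pt")
        return hook

    model = nn.Linear(4, 4)
    model.register_forward_hook(make_dump_hook("/tmp", "linear0"))
    model(torch.randn(1, 4))  # writes /tmp/linear0.pt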
@@ -1345,9 +1363,16 @@ class ModelRunner:
             return config
         return None
 
+    @property
+    def kimi_linear_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, KimiLinearConfig):
+            return config
+        return None
+
     @property
     def mambaish_config(self):
-        return self.mamba2_config or self.hybrid_gdn_config
+        return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
 
     def set_num_token_hybrid(self):
         if (
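Because Python's `or` returns its first truthy operand, `mambaish_config` now yields whichever hybrid-family config is present, or None when none is. A minimal sketch with stand-in values:

    mamba2_config = None
    hybrid_gdn_config = None
    kimi_linear_config = object()  # stand-in for a KimiLinearConfig instance

    # `or` returns the first non-None operand, so any of the three
    # hybrid families makes the combined property truthy.
    mambaish_config = mamba2_config or hybrid_gdn_config or kimi_linear_config
    assert mambaish_config is kimi_linear_config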
@@ -1658,9 +1683,11 @@ class ModelRunner:
                     get_attention_tp_size()
                 ),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.use_mla_backend and is_nsa_model:
             self.token_to_kv_pool = NSATokenToKVPool(
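Switching `layer_num` from `num_hidden_layers` to `num_effective_layers`, with explicit `start_layer`/`end_layer`, sizes the pool for only the layers this rank owns, which matters under pipeline parallelism. A worked example with hypothetical numbers:

    # Hypothetical 32-layer model split evenly across 4 pipeline stages.
    num_hidden_layers, pp_size, pp_rank = 32, 4, 1
    per_stage = num_hidden_layers // pp_size          # 8 layers per stage
    start_layer = pp_rank * per_stage                 # 8
    end_layer = start_layer + per_stage               # 16
    num_effective_layers = end_layer - start_layer
    assert num_effective_layers == 8                  # allocate 8 layers, not 32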
@@ -1676,7 +1703,7 @@ class ModelRunner:
                 end_layer=self.end_layer,
                 index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
             )
-        elif self.use_mla_backend:
+        elif self.use_mla_backend and not self.mambaish_config:
             assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1720,6 +1747,12 @@ class ModelRunner:
                 device=self.device,
             )
         elif config := self.mambaish_config:
+            extra_args = {}
+            if self.use_mla_backend:
+                extra_args = {
+                    "kv_lora_rank": self.model_config.kv_lora_rank,
+                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
+                }
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1735,6 +1768,8 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
                 mamba_pool=self.req_to_token_pool.mamba_pool,
+                use_mla=self.use_mla_backend,
+                **extra_args,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
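The two hunks above use a conditional-kwargs pattern: the MLA-only arguments are gathered in a dict and splatted with `**extra_args`, so non-MLA hybrid models never pass them. A minimal sketch (the sizes are illustrative, not taken from the diff):

    def build_pool(*, use_mla, kv_lora_rank=None, qk_rope_head_dim=None):
        # Stand-in for HybridLinearKVPool: MLA fields stay None unless passed.
        return {"use_mla": use_mla, "kv_lora_rank": kv_lora_rank,
                "qk_rope_head_dim": qk_rope_head_dim}

    use_mla = True
    extra_args = {"kv_lora_rank": 512, "qk_rope_head_dim": 64} if use_mla else {}
    pool = build_pool(use_mla=use_mla, **extra_args)
    assert pool["kv_lora_rank"] == 512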
@@ -1750,6 +1785,7 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_alt_stream=not self.server_args.enable_pdmux,
                 enable_kv_cache_copy=(
                     self.server_args.speculative_algorithm is not None
                 ),
@@ -1818,12 +1854,18 @@ class ModelRunner:
 
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
 
-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()
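Under pdmux, prefill keeps one shared backend while decode gets one backend per SM group, switched by stream index via `update_decode_attn_backend` (added further below). A stripped-down sketch of that wiring, using plain objects as stand-ins for real attention backends:

    class RunnerSketch:
        def __init__(self, sm_group_num: int):
            # One decode backend per SM group; index 0 is the default.
            self.decode_attn_backend_group = [object() for _ in range(sm_group_num)]
            self.decode_attn_backend = self.decode_attn_backend_group[0]

        def update_decode_attn_backend(self, stream_idx: int) -> None:
            self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]

    r = RunnerSketch(sm_group_num=4)
    r.update_decode_attn_backend(2)
    assert r.decode_attn_backend is r.decode_attn_backend_group[2]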
@@ -1837,10 +1879,12 @@ class ModelRunner:
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(
@@ -1854,7 +1898,8 @@ class ModelRunner:
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )
 
         (
@@ -1863,9 +1908,12 @@ class ModelRunner:
         ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend
 
-    def _get_attention_backend_from_str(self, backend_str: str):
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)
 
@@ -1963,6 +2011,9 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,
@@ -1970,7 +2021,11 @@ class ModelRunner:
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.attn_backend.init_forward_metadata(forward_batch)
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
@@ -2108,18 +2163,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
            )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
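The last hunk only reorders branches: `is_split_prefill()` is now tested before the broader `is_extend()`, which matters if `is_extend()` also returns True for split-prefill batches (the reorder suggests it does). A toy illustration of why the more specific predicate must come first:

    def dispatch(mode: str) -> str:
        # Toy stand-ins: "split_prefill" also counts as an extend-like mode,
        # mirroring how overlapping ForwardMode predicates can both be true.
        is_extend = mode in ("extend", "split_prefill")
        is_split_prefill = mode == "split_prefill"
        if is_split_prefill:        # specific branch first
            return "forward_split_prefill"
        elif is_extend:             # broader branch second
            return "forward_extend"
        return "forward_idle"

    assert dispatch("split_prefill") == "forward_split_prefill"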
sglang/srt/model_executor/npu_graph_runner.py

@@ -75,9 +75,13 @@ class NPUGraphRunner(CudaGraphRunner):
 
         # Replay
         if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
-            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
-                self.bs - self.raw_bs
-            )
+            if forward_batch.forward_mode.is_target_verify():
+                seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
+                seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
+            else:
+                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                    self.bs - self.raw_bs
+                )
             thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
             thread.start()
             self.graphs[self.bs].replay()
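In target-verify mode each request carries `num_tokens_per_bs` speculative tokens, so the sequence lengths fed to the captured graph must grow by that amount before padding up to the graph batch size. A worked example with hypothetical values:

    import torch

    seq_lens = torch.tensor([7, 12])  # raw_bs = 2 live requests
    num_tokens_per_bs = 4             # speculative tokens verified per request
    bs = 4                            # batch size the graph was captured with

    padded = (seq_lens + num_tokens_per_bs).tolist() + [0] * (bs - len(seq_lens))
    assert padded == [11, 16, 0, 0]   # grown lengths, padded with zeros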
sglang/srt/model_loader/weight_utils.py

@@ -238,7 +238,7 @@ def get_quant_config(
    if model_config.quantization == "bitsandbytes":
        config["adapter_name_or_path"] = model_name_or_path
    elif model_config.quantization.startswith("modelopt") and (
-        config["producer"]["name"].startswith("modelopt")
+        config.get("producer", {}).get("name", "").startswith("modelopt")
    ):
        quant_algo = config["quantization"]["quant_algo"]
        if quant_algo is None:
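The one-line fix replaces chained subscripts with guarded `.get()` lookups, so quantization configs that lack a `producer` block no longer raise `KeyError`. A minimal illustration (the config dict here is hypothetical):

    config = {"quantization": {"quant_algo": "FP8"}}  # no "producer" key

    # Old lookup: config["producer"]["name"]  ->  KeyError: 'producer'
    name = config.get("producer", {}).get("name", "")
    assert not name.startswith("modelopt")  # falls through safely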
sglang/srt/models/bailing_moe.py

@@ -420,14 +420,21 @@ class BailingMoEAttention(nn.Module):
         attn_tp_size = get_attention_tp_size()
 
         assert self.total_num_heads % attn_tp_size == 0
-        assert self.total_kv_heads % attn_tp_size == 0
+        if self.total_kv_heads >= attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_kv_heads % attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert attn_tp_size % self.total_kv_heads == 0
         assert self.total_num_heads >= self.total_kv_heads
 
         self.num_heads = self.total_num_heads // attn_tp_size
         self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
         self.q_size = self.head_dim * self.num_heads
 
-        self.num_kv_heads = self.total_kv_heads // attn_tp_size
+        self.num_kv_heads = max(1, self.total_kv_heads // attn_tp_size)
         self.kv_size = max(1, self.num_kv_heads * self.head_dim)
 
         self.scale = self.head_dim**-0.5
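The `max(1, ...)` floor is what makes replication work: plain integer division would give 0 local KV heads whenever `total_kv_heads < attn_tp_size`. A worked example with hypothetical head counts:

    def local_kv_heads(total_kv_heads: int, attn_tp_size: int) -> int:
        if total_kv_heads >= attn_tp_size:
            assert total_kv_heads % attn_tp_size == 0  # partition across ranks
        else:
            assert attn_tp_size % total_kv_heads == 0  # replicate across ranks
        return max(1, total_kv_heads // attn_tp_size)

    assert local_kv_heads(8, 4) == 2  # 8 KV heads split over 4 TP ranks
    assert local_kv_heads(2, 8) == 1  # 2 KV heads replicated on 8 TP ranks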
sglang/srt/models/deepseek_nextn.py

@@ -38,12 +38,13 @@ from sglang.srt.models.deepseek_v2 import (
     enable_nextn_moe_bf16_cast_to_fp8,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_npu
 
 logger = logging.getLogger(__name__)
 
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 
 class DeepseekModelNextN(nn.Module):
@@ -85,13 +86,21 @@ class DeepseekModelNextN(nn.Module):
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
 
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+
+        layer_name = "decoder"
+        if _is_npu and (
+            get_global_server_args().speculative_draft_model_path
+            == get_global_server_args().model_path
+        ):
+            layer_name = "layers." + str(config.num_hidden_layers)
+
         self.decoder = DeepseekV2DecoderLayer(
             config,
             0,
             quant_config=quant_config,
             moe_quant_config=moe_quant_config,
             is_nextn=True,
-            prefix=add_prefix("decoder", prefix),
+            prefix=add_prefix(layer_name, prefix),
             alt_stream=self.alt_stream,
         )
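On NPU, when the draft model shares a checkpoint with the target model, the NextN decoder is registered under `layers.<num_hidden_layers>` instead of `decoder`, so its parameter names line up with the extra MTP layer in the shared weights. A small sketch of the resulting prefix, assuming `add_prefix` joins names with a dot (its exact definition in sglang.srt.utils is not shown in this diff):

    def add_prefix(name: str, prefix: str) -> str:
        # Assumed behavior of sglang.srt.utils.add_prefix: dot-join unless empty.
        return name if not prefix else f"{prefix}.{name}"

    num_hidden_layers = 61  # hypothetical, e.g. a DeepSeek-V3-sized model
    assert add_prefix("layers." + str(num_hidden_layers), "model") == "model.layers.61"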