sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/forward_batch_info.py
@@ -320,17 +320,30 @@ class ForwardBatch:
 
         # For DP attention
         if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
            ).to(device, non_blocking=True)
 
-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
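
For orientation, here is a minimal standalone sketch (plain Python, not the sglang API) of the scaling introduced above: when speculative decoding is active, each request contributes spec_num_draft_tokens candidate tokens per step, so the per-rank token counts used for DP-attention padding and for sizing gathered_buffer are multiplied by that factor.

# Illustrative only; names mirror the diff, the helper itself is hypothetical.
def scale_global_num_tokens(global_num_tokens, spec_num_draft_tokens=None):
    factor = spec_num_draft_tokens if spec_num_draft_tokens is not None else 1
    return [x * factor for x in global_num_tokens]

assert scale_global_num_tokens([3, 5]) == [3, 5]       # normal decoding
assert scale_global_num_tokens([3, 5], 4) == [12, 20]  # 4 draft tokens per step
assert sum(scale_global_num_tokens([3, 5], 4)) == 32   # rows of gathered_buffer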

sglang/srt/model_executor/model_runner.py
@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,
@@ -70,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    PagedTokenToKVPoolAllocator,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -93,6 +97,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
+    dynamic_import,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
 
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -149,7 +155,7 @@ class ModelRunner:
         server_args: ServerArgs,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -162,6 +168,7 @@
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -195,6 +202,7 @@
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )
 
@@ -218,6 +226,7 @@
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
+
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -272,6 +281,10 @@
             self.apply_torch_tp()
 
         # Init lora
+        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
+        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
+        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
+        # load LoRA adapters dynamically later.
         if server_args.lora_paths is not None:
             self.init_lora_manager()
 
@@ -299,7 +312,7 @@
         if (
             server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
-            and not cpu_has_amx_support()
+            and not _is_cpu_amx_available
         ):
             logger.info(
                 "The current platform does not support Intel AMX, will fallback to torch_native backend."
@@ -543,7 +556,7 @@
             monkey_patch_vllm_parallel_state()
             monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -761,6 +774,9 @@
         ]
         if load_format == "direct":
             _model_load_weights_direct(self.model, named_tensors)
+        elif load_format in self.server_args.custom_weight_loader:
+            custom_loader = dynamic_import(load_format)
+            custom_loader(self.model, named_tensors)
         elif load_format is None:
             self.model.load_weights(named_tensors)
         else:
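
The new branch above resolves load_format as a dotted import path when it appears in server_args.custom_weight_loader. The snippet below is an illustrative stand-in for such a dynamic_import helper built on importlib; the helper name, the example path, and the exact semantics are assumptions for this sketch, not sglang's implementation.

import importlib

def dynamic_import_sketch(dotted_path: str):
    # "pkg.module.attr" -> import pkg.module and return its attribute
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# A custom weight loader would be referenced the same way, e.g.
# "my_pkg.loaders.load_my_weights" (hypothetical); here we resolve a stdlib
# function instead so the sketch runs anywhere.
dumps = dynamic_import_sketch("json.dumps")
assert dumps({"ok": True}) == '{"ok": true}'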

@@ -787,7 +803,6 @@
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
-            lora_paths=self.server_args.lora_paths,
             base_hf_config=self.model_config.hf_config,
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
@@ -796,6 +811,7 @@
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
+        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
         logger.info("LoRA manager ready.")
 
     def profile_max_num_token(self, total_gpu_memory: int):
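
As the TODO above notes, adapter loading is now decoupled from LoRAManager construction: the manager is created first and adapters are attached through load_lora_adapters afterwards. A toy sketch of that construct-then-load flow (class, arguments, and paths here are hypothetical, shown only to illustrate the split):

class ToyLoRAManager:
    def __init__(self, base_model):
        self.base_model = base_model
        self.adapters = {}

    def load_lora_adapters(self, lora_paths):
        # A real manager would load and register adapter weights here.
        for name, path in (lora_paths or {}).items():
            self.adapters[name] = path

manager = ToyLoRAManager(base_model="toy-base-model")
manager.load_lora_adapters({"my_adapter": "/path/to/adapter"})  # hypothetical path
assert "my_adapter" in manager.adapters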

sglang/srt/model_loader/loader.py
@@ -337,7 +337,14 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+            weight_loader_disable_mmap = global_server_args_dict.get(
+                "weight_loader_disable_mmap"
+            )
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files, disable_mmap=weight_loader_disable_mmap
+            )
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 

sglang/srt/model_loader/weight_utils.py
@@ -34,6 +34,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once
 
 logger = logging.getLogger(__name__)
@@ -206,7 +207,10 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
        if config["producer"]["name"] == "modelopt":
-            return quant_cls.from_config(config)
+            if "FP4" in config["quantization"]["quant_algo"]:
+                return ModelOptFp4Config.from_config(config)
+            else:
+                return quant_cls.from_config(config)
         else:
             raise ValueError(
                 f"Unsupported quantization config"
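
A minimal sketch of the dispatch above: the quant_algo string inside the ModelOpt-produced quantization config decides whether the FP4 config class is used; everything else falls through to the previously selected class. The config dicts below are illustrative, not real checkpoint contents.

def is_modelopt_fp4(hf_quant_config: dict) -> bool:
    return "FP4" in hf_quant_config["quantization"]["quant_algo"]

assert is_modelopt_fp4({"quantization": {"quant_algo": "NVFP4"}})
assert not is_modelopt_fp4({"quantization": {"quant_algo": "FP8"}})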

@@ -418,6 +422,7 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
     decryption_key: Optional[str] = None,
+    disable_mmap: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.
 
@@ -439,7 +444,11 @@
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        result = safetensors.torch.load_file(st_file, device="cpu")
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param
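
A runnable illustration of the two loading paths above: safetensors.torch.load_file memory-maps the checkpoint shard, while safetensors.torch.load on raw bytes forces a full read into host memory, which is what the new disable_mmap flag selects.

import os
import tempfile

import safetensors.torch
import torch

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.safetensors")
    safetensors.torch.save_file({"w": torch.ones(4)}, path)

    mmap_result = safetensors.torch.load_file(path, device="cpu")  # mmap-backed
    with open(path, "rb") as f:
        bytes_result = safetensors.torch.load(f.read())            # full read

    assert torch.equal(mmap_result["w"], bytes_result["w"])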

sglang/srt/models/deepseek_nextn.py
@@ -22,7 +22,6 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -45,6 +44,12 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -77,6 +82,7 @@
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -90,15 +96,16 @@
         else:
             hidden_states = input_embeds
 
-        hidden_states = self.eh_proj(
-            torch.cat(
-                (
-                    self.enorm(hidden_states),
-                    self.hnorm(forward_batch.spec_info.hidden_states),
-                ),
-                dim=-1,
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
             )
-        )
 
         residual = None
         hidden_states, residual = self.decoder(
@@ -106,7 +113,11 @@
         )
 
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            if residual is not None:
+                hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            else:
+                hidden_states = self.shared_head.norm(hidden_states)
+
         return hidden_states
 
 
@@ -127,23 +138,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(

sglang/srt/models/deepseek_v2.py
@@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -95,8 +95,10 @@ from sglang.srt.utils import (
     LazyValue,
     add_prefix,
     bind_or_assign,
+    cpu_has_amx_support,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
@@ -107,9 +109,13 @@
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -220,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -232,6 +239,7 @@
         )
         self.config = config
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -269,6 +277,15 @@
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
@@ -332,10 +349,38 @@
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
-            return self.forward_normal(hidden_states)
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(hidden_states)
+            else:
+                return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
+    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+        final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
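
The new forward_normal_dual_stream path overlaps the shared-experts MLP (current stream) with the routed experts (alternate stream) for batches up to 1024 tokens. Below is a self-contained sketch of the same stream pattern using plain matmuls as stand-ins for the two expert paths; it is guarded so it only runs when a CUDA device is present and is not the sglang implementation.

import torch

def dual_stream_sketch(x, w_shared, w_routed, alt_stream):
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)   # alt stream must observe x fully written
    shared = x @ w_shared             # "shared experts" on the current stream
    with torch.cuda.stream(alt_stream):
        routed = x @ w_routed         # "routed experts" overlap on the alt stream
    current.wait_stream(alt_stream)   # rejoin before combining the results
    return routed + shared

if torch.cuda.is_available():
    x = torch.randn(64, 128, device="cuda")
    w_s = torch.randn(128, 128, device="cuda")
    w_r = torch.randn(128, 128, device="cuda")
    out = dual_stream_sketch(x, w_s, w_r, torch.cuda.Stream())
    assert torch.allclose(out, x @ w_r + x @ w_s, atol=1e-4)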

@@ -665,13 +710,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"
 
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -1040,13 +1086,16 @@
                 masked_m,
                 expected_m,
             )
-            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+            attn_bmm_output = (
+                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+            )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1059,10 +1108,21 @@
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
-        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
-        output, _ = self.o_proj(attn_output)
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        output, _ = self.o_proj(attn_bmm_output)
 
         return output
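
The non-quantized branch above now pre-allocates the bmm result in its final (tokens, heads * v_head_dim) layout and lets torch.bmm write into a transposed view of that buffer, instead of materializing a separate transpose + flatten copy before o_proj. A CPU-runnable sketch with illustrative shapes:

import torch

num_tokens, num_heads, kv_lora_rank, v_head_dim = 5, 4, 8, 16
attn_output = torch.randn(num_tokens, num_heads, kv_lora_rank)
w_vc = torch.randn(num_heads, kv_lora_rank, v_head_dim)

# Previous path: bmm, then transpose + flatten (an extra copy).
ref = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1).flatten(1, 2)

# New path: write the bmm result straight into the flattened buffer.
out = torch.empty(num_tokens, num_heads * v_head_dim)
torch.bmm(
    attn_output.transpose(0, 1),
    w_vc,
    out=out.view(num_tokens, num_heads, v_head_dim).transpose(0, 1),
)
assert torch.allclose(ref, out, atol=1e-5)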

@@ -1399,7 +1459,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1426,7 +1488,7 @@
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=config.num_hidden_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -1437,6 +1499,7 @@
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1479,6 +1542,7 @@
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )
@@ -1500,6 +1564,11 @@
             hidden_states, residual, forward_batch
         )
 
+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
@@ -1607,8 +1676,6 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -1692,7 +1759,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -1717,12 +1783,12 @@
         disable_reason = None
         if (
             not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (9, 0)
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
             or self.config.architectures[0] != architecture
             or self.config.n_routed_experts != 256
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif (
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
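
The gate above is relaxed from SM90 to SM80. A runnable check that mirrors only the capability part of the new condition (the helper name is ours, not sglang's; the real check also inspects the architecture and expert counts):

import torch

def shared_experts_fusion_capability_ok() -> bool:
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability("cuda") >= (8, 0)

print("capability gate passes:", shared_experts_fusion_capability_ok())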

@@ -1919,10 +1985,12 @@
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
         ):
-            self._weight_requant_ue8m0()
+            self._weight_requant_ue8m0(is_nextn)
 
-    def _weight_requant_ue8m0(self):
+    def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
         moe_layers = list(
@@ -1933,8 +2001,12 @@
             )
         )
 
-        for layer_id in range(self.config.num_hidden_layers):
-            layer = self.model.layers[layer_id]
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
 
             for module in [
                 layer.self_attn.fused_qkv_a_proj_with_mqa,
@@ -1946,7 +2018,7 @@
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
-            if layer_id in moe_layers:
+            if layer_id in moe_layers or is_nextn:
                 shared_experts = getattr(layer.mlp, "shared_experts", None)
                 if shared_experts is not None:
                     for module in [
@@ -2022,7 +2094,7 @@
 
         if self.num_fused_shared_experts > 0:
             assert self.num_fused_shared_experts == 1
-            logger.info("Shared experts fusion optimization enabled.")
+            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
         params_dict = dict(self.named_parameters())
         weight_names = []
@@ -2128,8 +2200,14 @@
                     ):
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        cat_dim = 0
+                        if self.quant_config is not None and (
+                            self.quant_config.get_name() == "awq"
+                            or self.quant_config.get_name() == "moe_wna16"
+                        ):
+                            cat_dim = 1
                         fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                         )
                         param_name = (
                             name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
@@ -2151,12 +2229,16 @@
                     "k_scale" in name or "v_scale" in name
                 ) and name not in params_dict:
                     # modelopt attn kv scale is named differently
-                    if any(scale in name for scale in ["k_scale", "v_scale"]):
-                        name = name.replace("_proj", "attn_mqa")
-                    else:
-                        logger.warning(
-                            f"Unknown scale found in checkpoint: {name}"
-                        )
+                    for scale in ["k_scale", "v_scale"]:
+                        if scale in name:
+                            name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                            break
+                    if name not in params_dict:
+                        # modelopt ckpt contains not needed weights for MTP module:
+                        # model.decoder.self_attn.attn_mqa.v_scale and
+                        # model.decoder.self_attn.attn_mqa.k_scale
+                        logger.warning(f"{name} not found in params_dict.")
+                        continue
                 param = params_dict[name]
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader
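
A pure-Python sketch of the scale remapping at the end of the hunk above: checkpoint entries carrying k_scale / v_scale are rewritten onto the attn_mqa parameter names, and entries that still have no match afterwards (such as the MTP-only scales mentioned in the comment) are skipped with a warning. The example names below are illustrative, not taken from a real checkpoint.

def remap_modelopt_kv_scale(name: str) -> str:
    for scale in ("k_scale", "v_scale"):
        if scale in name:
            return name.replace(f"{scale[0]}_proj", "attn_mqa")
    return name

assert (
    remap_modelopt_kv_scale("model.layers.0.self_attn.k_proj.k_scale")
    == "model.layers.0.self_attn.attn_mqa.k_scale"
)
assert (
    remap_modelopt_kv_scale("model.layers.0.self_attn.v_proj.v_scale")
    == "model.layers.0.self_attn.attn_mqa.v_scale"
)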