sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/baichuan.py CHANGED
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         position_embedding: str,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -338,11 +337,12 @@ class BaiChuanBaseForCausalLM(nn.Module):

         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
         if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)

     def forward(
@@ -353,7 +353,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -403,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", cache_config, quant_config)
+            super().__init__(config, "ROPE", quant_config)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", cache_config, quant_config)
+            super().__init__(config, "ALIBI", quant_config)


 EntryClass = [BaichuanForCausalLM]
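
Note on the change above: when tie_word_embeddings is set, the embedding module itself now serves as the LM head, and the logits processor receives the head module rather than its .weight tensor (the same signature change recurs in the other model files below). A minimal, self-contained sketch of the tying pattern in plain PyTorch; the class and names here are illustrative only, not sglang code:

import torch
from torch import nn

class TinyTiedLM(nn.Module):
    """Toy model showing the weight-tying pattern; not sglang code."""

    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # Reuse the embedding module as the LM head; no separate weight is allocated.
            self.lm_head = self.embed_tokens
        else:
            self.lm_head = nn.Embedding(vocab_size, hidden_size)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # A logits processor that receives the module can read its weight internally.
        return hidden_states @ self.lm_head.weight.t()

model = TinyTiedLM(vocab_size=100, hidden_size=16, tie_word_embeddings=True)
assert model.lm_head is model.embed_tokens  # shared parameters, no duplicate memory
print(model.logits(torch.randn(2, 16)).shape)  # torch.Size([2, 100])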
sglang/srt/models/chatglm.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig

 from sglang.srt.layers.activation import SiluAndMul
@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader

 LoraConfig = None

@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
         self,
         config,
         layer_id: int = 0,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
         self,
         config,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
         )

         # Self attention.
-        self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+        self.self_attention = GLMAttention(config, layer_id, quant_config)
         self.hidden_dropout = config.hidden_dropout

         # Layernorm on the attention output
@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):

         # Transformer layers.
         self.layers = nn.ModuleList(
-            [
-                GLMBlock(config, i, cache_config, quant_config)
-                for i in range(self.num_layers)
-            ]
+            [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
         )

         if self.post_layer_norm:
@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config, cache_config, quant_config)
+        self.encoder = GLMTransformer(config, quant_config)

         self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)

@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoraConfig] = None,
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
-        self.transformer = ChatGLMM(config, cache_config, quant_config)
+        self.transformer = ChatGLMM(config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)

@@ -378,7 +369,7 @@ class ChatGLMForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
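
Note: the recurring import change in these files swaps vLLM's default_weight_loader for the one in the new sglang.srt.model_loader.weight_utils module (added in this release, see the file list above). As a rough illustration of what a default weight loader of this kind does, here is a minimal sketch assuming the usual shape-check-then-copy behavior; it is not the actual sglang implementation:

import torch

def default_weight_loader_sketch(
    param: torch.nn.Parameter, loaded_weight: torch.Tensor
) -> None:
    """Minimal sketch of a 'default' weight loader: shape-check, then in-place copy."""
    assert param.data.shape == loaded_weight.shape, (
        f"shape mismatch: {tuple(param.data.shape)} vs {tuple(loaded_weight.shape)}"
    )
    param.data.copy_(loaded_weight)

# Usage: applied per parameter while iterating over (name, tensor) pairs from a
# checkpoint, which is how the load_weights() methods above consume such a loader.
p = torch.nn.Parameter(torch.empty(4, 4))
default_weight_loader_sketch(p, torch.ones(4, 4))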
sglang/srt/models/commandr.py CHANGED
@@ -49,7 +49,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -62,10 +61,11 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs


-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -339,7 +338,7 @@ class CohereForCausalLM(nn.Module):
             forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
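
Note: the decorator change above parameterizes torch.compile with backend=get_compiler_backend(), so the compile backend can be chosen per device instead of always using the default. A sketch of that pattern with a hypothetical pick_compiler_backend helper (the real get_compiler_backend in sglang.srt.utils may decide differently):

import torch

def pick_compiler_backend() -> str:
    # Hypothetical stand-in for sglang.srt.utils.get_compiler_backend:
    # use inductor when a CUDA device is present, otherwise fall back to eager.
    return "inductor" if torch.cuda.is_available() else "eager"

@torch.compile(backend=pick_compiler_backend())
def rms_norm_sketch(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Illustrative normalization body; the decorated function in commandr.py differs.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return (weight * hidden_states).to(input_dtype)

x = torch.randn(2, 8)
w = torch.ones(8)
print(rms_norm_sketch(x, w, 1e-6).shape)  # torch.Size([2, 8])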
sglang/srt/models/dbrx.py CHANGED
@@ -25,7 +25,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import set_weight_attrs


@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.config = config
@@ -390,7 +389,7 @@ class DbrxForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/deepseek.py CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class DeepseekMLP(nn.Module):
@@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         if (
@@ -330,7 +327,6 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -343,9 +339,7 @@ class DeepseekModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                DeepseekDecoderLayer(
-                    config, layer_id, cache_config, quant_config=quant_config
-                )
+                DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekModel(config, cache_config, quant_config)
+        self.model = DeepseekModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
@@ -394,7 +387,7 @@ class DeepseekForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/deepseek_v2.py CHANGED
@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -28,9 +29,9 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
@@ -112,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
             "Only silu is supported for now."
         )

-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
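
Note: DeepSeek-V2 now selects its MoE implementation at construction time, using the expert-parallel EPMoE layer when the enable_ep_moe server argument is set and the Triton FusedMoE otherwise. A toy sketch of the flag-driven selection, with stand-in classes and DeepSeek-V2-like sizes assumed for illustration:

from dataclasses import dataclass

@dataclass
class FusedMoESketch:  # stand-in for sglang's FusedMoE
    num_experts: int
    top_k: int

@dataclass
class EPMoESketch:  # stand-in for sglang's EPMoE
    num_experts: int
    top_k: int

def build_experts(enable_ep_moe: bool, num_experts: int, top_k: int):
    # Same shape of logic as DeepseekV2MoE.__init__ above: pick the class once,
    # then construct it with identical keyword arguments.
    moe_impl = EPMoESketch if enable_ep_moe else FusedMoESketch
    return moe_impl(num_experts=num_experts, top_k=top_k)

print(type(build_experts(True, num_experts=160, top_k=6)).__name__)   # EPMoESketch
print(type(build_experts(False, num_experts=160, top_k=6)).__name__)  # FusedMoESketch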
@@ -189,7 +191,6 @@ class DeepseekV2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -337,7 +338,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
@@ -455,7 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        self.attn = RadixAttention(
+        self.attn_mqa = RadixAttention(
             self.num_local_heads,
             self.kv_lora_rank + self.qk_rope_head_dim,
             self.scaling,
@@ -464,6 +464,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )

+        self.attn_mha = RadixAttention(
+            self.num_local_heads,
+            self.qk_nope_head_dim + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+            v_head_dim=self.v_head_dim,
+        )
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -473,6 +482,63 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Use normal computation for prefill and use weight absorption for extend/decode
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.extend_prefix_lens.sum() == 0
+        ):
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward_absorb(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -510,7 +576,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.w_vc.dtype == torch.float8_e4m3fn:
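
Note: both MLA paths store the same compressed latent per token, kv_lora_rank + qk_rope_head_dim values, instead of full per-head K/V, which is the main KV-cache saving of this attention design. A small arithmetic sketch with DeepSeek-V2-style dimensions assumed for illustration (128 heads, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, kv_lora_rank=512):

# Per-token KV-cache entries (element counts, independent of dtype).
num_heads = 128
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
kv_lora_rank = 512

# A conventional MHA cache stores full K and V for every head.
mha_cache = num_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)

# The MLA latent cache stores one compressed vector plus the shared RoPE part,
# matching the attn_mqa head dim (kv_lora_rank + qk_rope_head_dim) above.
mla_cache = kv_lora_rank + qk_rope_head_dim

print(mha_cache, mla_cache, round(mha_cache / mla_cache, 1))
# 40960 576 71.1  -> roughly a 71x smaller per-token KV footprint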
@@ -568,7 +634,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -599,7 +664,6 @@ class DeepseekV2DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
                 use_dp=self.enable_dp_attention,
@@ -619,7 +683,6 @@ class DeepseekV2DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -685,7 +748,6 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -702,7 +764,6 @@ class DeepseekV2Model(nn.Module):
                 DeepseekV2DecoderLayer(
                     config,
                     layer_id,
-                    cache_config=cache_config,
                     quant_config=quant_config,
                 )
                 for layer_id in range(config.num_hidden_layers)
@@ -733,13 +794,12 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, cache_config, quant_config)
+        self.model = DeepseekV2Model(config, quant_config)
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
@@ -763,7 +823,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch)
         if not forward_batch.forward_mode.is_idle():
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -775,7 +835,8 @@ class DeepseekV2ForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -836,14 +897,25 @@ class DeepseekV2ForCausalLM(nn.Module):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj


 EntryClass = DeepseekV2ForCausalLM
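
Note: the weight-absorption path needs kv_b_proj split into per-head K-nope and V blocks, which is what the unflatten(...).split(...) above does; the new AWQ branch only dequantizes kv_b_proj first so the same split applies. A shape-only sketch of that split with small stand-in dimensions:

import torch

# Stand-in dimensions, small and chosen only for illustration.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 8, 6, 16

# kv_b_proj maps the latent (kv_lora_rank) to per-head [k_nope | v] blocks, so its
# weight has shape (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank).
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape)  # torch.Size([4, 8, 16]) -> per-head K-nope projection block
print(w_vc.shape)  # torch.Size([4, 6, 16]) -> per-head V projection block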
sglang/srt/models/exaone.py CHANGED
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class ExaoneGatedMLP(nn.Module):
@@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -314,7 +313,7 @@ class ExaoneForCausalLM(nn.Module):
             input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma.py CHANGED
@@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class GemmaMLP(nn.Module):
@@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
-        cache_config=None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
@@ -298,7 +294,7 @@ class GemmaForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):