sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
 import pkgutil
@@ -56,9 +57,11 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
-    monkey_patch_vllm_dummy_weight_loader,
+    is_hip,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
 
@@ -113,7 +116,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            logger.warning(
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -139,8 +142,8 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
-                "disable_penalizer": server_args.disable_penalizer,
-                "disable_nan_detection": server_args.disable_nan_detection,
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
 
@@ -148,6 +151,15 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -215,6 +227,47 @@ class ModelRunner:
 
         return min_per_gpu_memory
 
+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            return get_model(
+                model_config=self.vllm_model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+                parallel_config=None,
+                scheduler_config=None,
+                lora_config=None,
+                cache_config=None,
+            )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+
+        if "task" in sig.parameters:
+            params["task"] = ""
+
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -232,42 +285,25 @@ class ModelRunner:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
         # Prepare the vllm model config
-        monkey_patch_vllm_dummy_weight_loader()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
-        self.dtype = self.vllm_model_config.dtype
 
-        # Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.dtype = self.vllm_model_config.dtype
 
         logger.info(
             f"Load weight end. "
@@ -293,17 +329,9 @@ class ModelRunner:
         target_device = torch.device(self.device)
 
         try:
-            # TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
@@ -412,7 +440,10 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            self.kv_cache_dtype = torch.float8_e5m2
+            if is_hip():  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e5m2fnuz
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -551,6 +582,13 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
@@ -576,21 +614,37 @@ class ModelRunner:
                 get_embedding=True,
             )
 
+    def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(forward_batch)
+        elif forward_batch.forward_mode.is_idle():
+            return self.forward_idle(forward_batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.update_regex_vocab_mask()
-        sampling_info.update_penalties()
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
 
         # Sample the next tokens.
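Note: sample() now distinguishes two modes. In overlap mode the CPU-heavy grammar/vocab-mask work is kicked off while the previous batch's results are processed, and sampling_info_done is only waited on when grammars are present; in normal mode the masks and penalties are still computed inline so they overlap with the GPU forward pass. A rough sketch of the handshake, assuming sampling_info_done behaves like a threading.Event (the stub class and worker function below are illustrative, not sglang's actual types):

    import threading

    class SamplingInfoStub:
        def __init__(self):
            self.sampling_info_done = threading.Event()
            self.vocab_mask = None

        def update_regex_vocab_mask(self):
            self.vocab_mask = "mask"  # stand-in for the real tensor work

    def background_prepare(info: SamplingInfoStub):
        # Runs while the previous batch's results are being processed.
        info.update_regex_vocab_mask()
        info.sampling_info_done.set()

    info = SamplingInfoStub()
    threading.Thread(target=background_prepare, args=(info,)).start()
    info.sampling_info_done.wait()  # sample() only blocks here
    assert info.vocab_mask is not None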
@@ -616,7 +670,7 @@ class ModelRunner:
 
         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
 
         return logits
 
@@ -640,7 +694,9 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+            logger.warning(f"Ignore import error when loading {name}. {e}")
+            if crash_on_warnings():
+                raise ValueError(f"Ignore import error when loading {name}. {e}")
             continue
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
sglang/srt/model_parallel.py (new file)

@@ -0,0 +1,98 @@
+"""
+Common utilities for torch model parallelism.
+"""
+
+from typing import Optional
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+try:
+    from torch.distributed.tensor import DTensor, Shard
+except ImportError:
+    # torch 2.4 or older
+    from torch.distributed._tensor import DTensor, Shard
+
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+
+class ColwiseParallelSharded(ColwiseParallel):
+    """
+    A version of ColwiseParallel where the local weight has been already
+    sharded. This is used for the fused wqkv case, where during loading, we
+    already sharded wq, wk, wv before fusing them.
+    """
+
+    # Override the _partition_linear_fn in ColwiseParallel
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(0)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        for name, param in module.named_parameters():
+            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
+            module.register_parameter(name, dist_param)
+
+
+class RowwiseParallelMaybeWait(RowwiseParallel):
+    """
+    A version of RowwiseParallel that waits for the output (establish dependency
+    between comm stream and compute stream in CUDA sense) before going into the
+    next op. This is needed to workaround the current interaction between
+    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
+    """
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        outputs = super(
+            RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
+        )._prepare_output_fn(
+            output_layouts, use_local_output, mod, outputs, device_mesh
+        )
+        # wait for the output to be ready
+        if isinstance(outputs, AsyncCollectiveTensor):
+            return outputs.wait()
+        else:
+            return outputs
+
+
+def tensor_parallel(
+    module: torch.nn.Module,
+    device_mesh: Optional[DeviceMesh] = None,
+):
+    """
+    Tensor parallelize the model across the given device mesh.
+    Args:
+        module (`torch.nn.Module`):
+            The module to tensor parallelize.
+        device_mesh (`torch.distributed.DeviceMesh`):
+            The device mesh to use for tensor parallelism.
+    """
+
+    # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+    # No op if `_tp_plan` attribute does not exist under the module.
+    # This is a helper function to be used with `model.apply` to recursively
+    # parallelize a model.
+    def tplize(mod: torch.nn.Module) -> None:
+        tp_plan = getattr(mod, "_tp_plan", None)
+        if tp_plan is None:
+            return
+        for child_name, tp_style in tp_plan.items():
+            submod = mod.get_submodule(child_name)
+            if tp_style == "Colwise":
+                parallelize_module(submod, device_mesh, ColwiseParallel())
+            elif tp_style == "Rowwise":
+                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
+            elif tp_style == "Colwise_Sharded":
+                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
+            else:
+                raise ValueError(f"Unknown TP style {tp_style}")
+
+    # `apply` is a native method of `nn.Module` that recursively applies a
+    # function to every submodule.
+    module.apply(tplize)
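Note: tensor_parallel() walks the module tree and parallelizes any submodule that declares a _tp_plan dict mapping child names to one of the styles above; ModelRunner applies it only when the model sets supports_torch_tp (see the model_runner.py hunk earlier). A minimal sketch of how a model might opt in, assuming torch.distributed has been initialized; the TinyBlock/TinyModel classes below are illustrative, not part of sglang:

    import torch
    import torch.nn as nn
    from sglang.srt.model_parallel import tensor_parallel

    class TinyBlock(nn.Module):
        # Maps child module names to TP styles understood by tplize().
        _tp_plan = {"up": "Colwise", "down": "Rowwise"}

        def __init__(self, dim: int = 128):
            super().__init__()
            self.up = nn.Linear(dim, 4 * dim, bias=False)
            self.down = nn.Linear(4 * dim, dim, bias=False)

        def forward(self, x):
            return self.down(torch.relu(self.up(x)))

    class TinyModel(nn.Module):
        supports_torch_tp = True  # checked by ModelRunner before applying TP

        def __init__(self):
            super().__init__()
            self.block = TinyBlock()

    # With a tp_size-rank process group initialized:
    # mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
    # tensor_parallel(TinyModel(), mesh)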
sglang/srt/models/deepseek_v2.py

@@ -22,7 +22,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        use_dp=False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.num_heads = num_heads
         tp_size = get_tensor_model_parallel_world_size()
         assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
+        if use_dp:
+            # For data parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ReplicatedLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ReplicatedLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
             )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
+            # O projection.
+            self.o_proj = ReplicatedLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
             )
         else:
-            self.q_proj = ColumnParallelLinear(
+            # For tensor parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ColumnParallelLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ColumnParallelLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=quant_config,
+            )
+            # O projection.
+            self.o_proj = RowParallelLinear(
+                self.num_heads * self.v_head_dim,
                 self.hidden_size,
-                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
             )
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+def all_gather(
+    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
+):
+    if world_size == 1:
+        return input_tensor
+
+    all_lens = forward_batch.global_num_tokens
+    max_len = max(forward_batch.global_num_tokens)
+
+    padded_tensor = torch.nn.functional.pad(
+        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
+    )
+
+    torch.distributed.all_gather_into_tensor(
+        forward_batch.gathered_buffer, padded_tensor, group=group
+    )
+
+    gathered_tensors = torch.concat(
+        [
+            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
+            for i in range(world_size)
+        ]
+    )
+
+    start_index = 0 if rank == 0 else sum(all_lens[:rank])
+    end_index = start_index + all_lens[rank]
+
+    return gathered_tensors, start_index, end_index
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
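Note: with DP attention each rank runs attention on its own requests, then the ranks exchange hidden states before the tensor-parallel MoE/MLP: every rank pads its tokens to the batch-wide maximum, all-gathers into forward_batch.gathered_buffer, strips the per-rank padding, and later slices its own rows back out via start_index/end_index. A single-process illustration of just the indexing arithmetic (no process group involved; the gathered buffer is simulated):

    import torch

    global_num_tokens = [3, 5]  # tokens held by each DP rank
    world_size, rank = 2, 1
    max_len = max(global_num_tokens)
    hidden = 4

    # Simulated result of all_gather_into_tensor: world_size * max_len padded rows.
    gathered_buffer = torch.arange(
        world_size * max_len * hidden, dtype=torch.float32
    ).view(-1, hidden)

    # Drop the per-rank padding rows, mirroring the list comprehension above.
    gathered = torch.concat(
        [gathered_buffer[i * max_len : i * max_len + global_num_tokens[i]] for i in range(world_size)]
    )
    assert gathered.shape[0] == sum(global_num_tokens)  # 8 unpadded rows

    # Rank 1 owns rows [3, 8) of the unpadded concatenation.
    start_index = 0 if rank == 0 else sum(global_num_tokens[:rank])
    end_index = start_index + global_num_tokens[rank]
    assert (start_index, end_index) == (3, 8)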
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.enable_dp_attention = (
+            not global_server_args_dict["disable_mla"]
+            and global_server_args_dict["enable_dp_attention"]
+        )
+        if self.enable_dp_attention:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_group = get_tp_group().device_group
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                use_dp=self.enable_dp_attention,
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        if self.enable_dp_attention:
+            hidden_states, start_idx, end_idx = all_gather(
+                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
+            )
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = hidden_states[start_idx:end_idx]
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         return hidden_states, residual
 
 
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = nn.ModuleList(
             [
630
723
  hidden_states, residual = layer(
631
724
  positions, hidden_states, forward_batch, residual
632
725
  )
633
- hidden_states, _ = self.norm(hidden_states, residual)
726
+ if not forward_batch.forward_mode.is_idle():
727
+ hidden_states, _ = self.norm(hidden_states, residual)
634
728
  return hidden_states
635
729
 
636
730
 
@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
-        self.logits_processor = LogitsProcessor(config)
+        if global_server_args_dict["enable_dp_attention"]:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+            self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+        if not forward_batch.forward_mode.is_idle():
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [