sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +71 -1
- sglang/check_env.py +3 -6
- sglang/srt/constrained/outlines_backend.py +15 -2
- sglang/srt/constrained/xgrammar_backend.py +22 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +204 -54
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +99 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +27 -1
- sglang/srt/server_args.py +78 -62
- sglang/srt/utils.py +71 -52
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +30 -19
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
 import pkgutil
@@ -56,9 +57,11 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
-
+    is_hip,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
 
@@ -113,7 +116,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            logger.
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -139,8 +142,8 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
-                "
-                "
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
 
@@ -148,6 +151,15 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -215,6 +227,47 @@ class ModelRunner:
 
         return min_per_gpu_memory
 
+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            return get_model(
+                model_config=self.vllm_model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+                parallel_config=None,
+                scheduler_config=None,
+                lora_config=None,
+                cache_config=None,
+            )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+
+        if "task" in sig.parameters:
+            params["task"] = ""
+
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
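The get_model_config_params() helper added above keeps sglang working across vLLM releases whose VllmModelConfig.__init__ signatures differ: it builds the kwargs dict once and only adds "task" when the installed signature accepts it. A minimal sketch of that pattern, using made-up ConfigV1/ConfigV2 classes and a build_kwargs helper that are not part of sglang or vLLM:

import inspect

class ConfigV1:
    def __init__(self, model, dtype="auto"):
        self.model, self.dtype = model, dtype

class ConfigV2:
    def __init__(self, model, dtype="auto", task=""):
        self.model, self.dtype, self.task = model, dtype, task

def build_kwargs(config_cls, model):
    # Only pass "task" when the installed config class accepts it.
    sig = inspect.signature(config_cls.__init__)
    kwargs = {"model": model, "dtype": "auto"}
    if "task" in sig.parameters:
        kwargs["task"] = ""
    return kwargs

print(build_kwargs(ConfigV1, "m"))  # {'model': 'm', 'dtype': 'auto'}
print(build_kwargs(ConfigV2, "m"))  # same dict plus 'task': ''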
@@ -232,42 +285,25 @@ class ModelRunner:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
         # Prepare the vllm model config
-        monkey_patch_vllm_dummy_weight_loader()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-
-
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
-        self.dtype = self.vllm_model_config.dtype
 
-
-
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.dtype = self.vllm_model_config.dtype
 
         logger.info(
             f"Load weight end. "
@@ -293,17 +329,9 @@ class ModelRunner:
         target_device = torch.device(self.device)
 
         try:
-
-
-
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
@@ -412,7 +440,10 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-
+            if is_hip():  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e5m2fnuz
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -551,6 +582,13 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
@@ -576,21 +614,37 @@ class ModelRunner:
             get_embedding=True,
         )
 
+    def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(forward_batch)
+        elif forward_batch.forward_mode.is_idle():
+            return self.forward_idle(forward_batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.
-
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
 
         # Sample the next tokens.
|
|
616
670
|
|
617
671
|
# Apply regex vocab_mask
|
618
672
|
if sampling_info.vocab_mask is not None:
|
619
|
-
logits =
|
673
|
+
sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
|
620
674
|
|
621
675
|
return logits
|
622
676
|
|
@@ -640,7 +694,9 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}.
+            logger.warning(f"Ignore import error when loading {name}. {e}")
+            if crash_on_warnings():
+                raise ValueError(f"Ignore import error when loading {name}. {e}")
             continue
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
sglang/srt/model_parallel.py
ADDED
@@ -0,0 +1,98 @@
+"""
+Common utilities for torch model parallelism.
+"""
+
+from typing import Optional
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+try:
+    from torch.distributed.tensor import DTensor, Shard
+except ImportError:
+    # torch 2.4 or older
+    from torch.distributed._tensor import DTensor, Shard
+
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+
+class ColwiseParallelSharded(ColwiseParallel):
+    """
+    A version of ColwiseParallel where the local weight has been already
+    sharded. This is used for the fused wqkv case, where during loading, we
+    already sharded wq, wk, wv before fusing them.
+    """
+
+    # Override the _partition_linear_fn in ColwiseParallel
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(0)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        for name, param in module.named_parameters():
+            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
+            module.register_parameter(name, dist_param)
+
+
+class RowwiseParallelMaybeWait(RowwiseParallel):
+    """
+    A version of RowwiseParallel that waits for the output (establish dependency
+    between comm stream and compute stream in CUDA sense) before going into the
+    next op. This is needed to workaround the current interaction between
+    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
+    """
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        outputs = super(
+            RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
+        )._prepare_output_fn(
+            output_layouts, use_local_output, mod, outputs, device_mesh
+        )
+        # wait for the output to be ready
+        if isinstance(outputs, AsyncCollectiveTensor):
+            return outputs.wait()
+        else:
+            return outputs
+
+
+def tensor_parallel(
+    module: torch.nn.Module,
+    device_mesh: Optional[DeviceMesh] = None,
+):
+    """
+    Tensor parallelize the model across the given device mesh.
+    Args:
+        module (`torch.nn.Module`):
+            The module to tensor parallelize.
+        device_mesh (`torch.distributed.DeviceMesh`):
+            The device mesh to use for tensor parallelism.
+    """
+
+    # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+    # No op if `_tp_plan` attribute does not exist under the module.
+    # This is a helper function to be used with `model.apply` to recursively
+    # parallelize a model.
+    def tplize(mod: torch.nn.Module) -> None:
+        tp_plan = getattr(mod, "_tp_plan", None)
+        if tp_plan is None:
+            return
+        for child_name, tp_style in tp_plan.items():
+            submod = mod.get_submodule(child_name)
+            if tp_style == "Colwise":
+                parallelize_module(submod, device_mesh, ColwiseParallel())
+            elif tp_style == "Rowwise":
+                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
+            elif tp_style == "Colwise_Sharded":
+                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
+            else:
+                raise ValueError(f"Unknown TP style {tp_style}")
+
+    # `apply` is a native method of `nn.Module` that recursively applies a
+    # function to every submodule.
+    module.apply(tplize)
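tensor_parallel() is driven entirely by a per-module _tp_plan mapping, and ModelRunner only calls apply_torch_tp() when the model sets supports_torch_tp. A rough sketch of how a model might opt in, with ToyBlock/ToyModel as made-up stand-ins rather than sglang classes:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Mapping from child name to TP style, consumed by tensor_parallel().
    _tp_plan = {
        "up_proj": "Colwise",
        "down_proj": "Rowwise",
    }

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.up_proj = nn.Linear(hidden, 4 * hidden, bias=False)
        self.down_proj = nn.Linear(4 * hidden, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(torch.relu(self.up_proj(x)))

class ToyModel(nn.Module):
    supports_torch_tp = True  # flag checked by ModelRunner before apply_torch_tp()

    def __init__(self):
        super().__init__()
        self.block = ToyBlock()

# Under torch.distributed with tp_size ranks, ModelRunner would roughly do:
#   device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
#   tensor_parallel(model, device_mesh)
# after which up_proj weights are sharded column-wise and down_proj row-wise.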
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -22,7 +22,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        use_dp=False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.num_heads = num_heads
         tp_size = get_tensor_model_parallel_world_size()
         assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if
-
-
-        self.
+        if use_dp:
+            # For data parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ReplicatedLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ReplicatedLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
             )
-
-        self.
-
-        self.
+            # O projection.
+            self.o_proj = ReplicatedLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
             )
         else:
-
+            # For tensor parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ColumnParallelLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ColumnParallelLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=quant_config,
+            )
+            # O projection.
+            self.o_proj = RowParallelLinear(
+                self.num_heads * self.v_head_dim,
                 self.hidden_size,
-            self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
             )
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+def all_gather(
+    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
+):
+    if world_size == 1:
+        return input_tensor
+
+    all_lens = forward_batch.global_num_tokens
+    max_len = max(forward_batch.global_num_tokens)
+
+    padded_tensor = torch.nn.functional.pad(
+        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
+    )
+
+    torch.distributed.all_gather_into_tensor(
+        forward_batch.gathered_buffer, padded_tensor, group=group
+    )
+
+    gathered_tensors = torch.concat(
+        [
+            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
+            for i in range(world_size)
+        ]
+    )
+
+    start_index = 0 if rank == 0 else sum(all_lens[:rank])
+    end_index = start_index + all_lens[rank]
+
+    return gathered_tensors, start_index, end_index
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
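The all_gather() helper above pads each rank's hidden states to the longest per-rank token count, gathers into forward_batch.gathered_buffer, drops the padding rows, and remembers which slice belongs to the local rank. A single-process sketch of the same padding/concat/slice arithmetic (per_rank_tokens is invented for illustration; no torch.distributed call is made):

import torch

per_rank_tokens = [3, 5, 2]          # tokens held by each DP rank (illustrative)
hidden = 4
max_len = max(per_rank_tokens)

# Each rank pads its [num_tokens, hidden] tensor to [max_len, hidden].
padded = [
    torch.nn.functional.pad(torch.randn(n, hidden), (0, 0, 0, max_len - n))
    for n in per_rank_tokens
]

# all_gather_into_tensor would produce one [world_size * max_len, hidden] buffer.
gathered_buffer = torch.cat(padded)

# Drop the padding rows, keeping only the real tokens from every rank.
gathered = torch.cat(
    [gathered_buffer[i * max_len : i * max_len + n] for i, n in enumerate(per_rank_tokens)]
)

# Rank r later slices its own rows back out after the shared MLP/MoE pass.
rank = 1
start = sum(per_rank_tokens[:rank])
end = start + per_rank_tokens[rank]
print(gathered.shape, gathered[start:end].shape)  # torch.Size([10, 4]) torch.Size([5, 4])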
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.enable_dp_attention = (
+            not global_server_args_dict["disable_mla"]
+            and global_server_args_dict["enable_dp_attention"]
+        )
+        if self.enable_dp_attention:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_group = get_tp_group().device_group
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                use_dp=self.enable_dp_attention,
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-        if
-            residual
-
-
-
-
-
-            hidden_states=
-
-
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-
-
+        if self.enable_dp_attention:
+            hidden_states, start_idx, end_idx = all_gather(
+                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
+            )
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = hidden_states[start_idx:end_idx]
+        else:
+            hidden_states = self.mlp(hidden_states)
 
         return hidden_states, residual
 
 
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = nn.ModuleList(
             [
@@ -630,7 +723,8 @@ class DeepseekV2Model(nn.Module):
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-
-
-
-
+        if global_server_args_dict["enable_dp_attention"]:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+            self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
-
-
-
+        if not forward_batch.forward_mode.is_idle():
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [