sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +337 -0
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +115 -31
- sglang/check_env.py +3 -6
- sglang/srt/constrained/base_grammar_backend.py +4 -3
- sglang/srt/constrained/outlines_backend.py +39 -26
- sglang/srt/constrained/xgrammar_backend.py +58 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -14
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +210 -56
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +102 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +30 -27
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +3 -3
- sglang/srt/server.py +29 -2
- sglang/srt/server_args.py +97 -60
- sglang/srt/utils.py +103 -51
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +33 -22
- sglang/version.py +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
 import pkgutil
@@ -56,9 +57,11 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
-
+    is_hip,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )

@@ -113,7 +116,7 @@ class ModelRunner:
         )

         if self.is_multimodal:
-            logger.
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -139,8 +142,8 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
-                "
-                "
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
             }
         )

@@ -148,6 +151,15 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
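The torch TP gate above is an attribute lookup on the loaded model: only models that advertise `supports_torch_tp` are wrapped. A minimal sketch of that gating logic; `ToyModel` and `maybe_apply_torch_tp` are illustrative stand-ins, not sglang APIs.

```python
class ToyModel:
    supports_torch_tp = True  # models without this attribute default to False


def maybe_apply_torch_tp(model, tp_size, apply_fn):
    supports_torch_tp = getattr(model, "supports_torch_tp", False)
    if tp_size > 1 and supports_torch_tp:
        apply_fn(model)
        return True
    return False


print(maybe_apply_torch_tp(ToyModel(), tp_size=2, apply_fn=lambda m: None))  # True
print(maybe_apply_torch_tp(object(), tp_size=2, apply_fn=lambda m: None))    # False
```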
@@ -215,6 +227,47 @@ class ModelRunner:

         return min_per_gpu_memory

+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            return get_model(
+                model_config=self.vllm_model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+                parallel_config=None,
+                scheduler_config=None,
+                lora_config=None,
+                cache_config=None,
+            )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+
+        if "task" in sig.parameters:
+            params["task"] = ""
+
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
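The `inspect.signature` probe in `get_model_config_params` is a version-compatibility trick: a keyword is only passed when the target constructor actually declares it. A small sketch of the same idea with stand-in config classes (not vllm's):

```python
import inspect


class ConfigV1:
    def __init__(self, model, dtype):
        self.model, self.dtype = model, dtype


class ConfigV2:
    def __init__(self, model, dtype, task=""):
        self.model, self.dtype, self.task = model, dtype, task


def build_params(config_cls, model, dtype):
    params = {"model": model, "dtype": dtype}
    # Only add "task" when the constructor accepts it.
    if "task" in inspect.signature(config_cls.__init__).parameters:
        params["task"] = ""
    return params


print(build_params(ConfigV1, "my-model", "auto"))  # no "task" key
print(build_params(ConfigV2, "my-model", "auto"))  # includes "task"
```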
@@ -232,39 +285,25 @@ class ModelRunner:
             raise RuntimeError("SGLang only supports sm75 and above.")

         # Prepare the vllm model config
-
-
-
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
+        self.load_config = LoadConfig(
+            load_format=self.server_args.load_format,
+            download_dir=self.server_args.download_dir,
         )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
-        self.dtype = self.vllm_model_config.dtype

-
-
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.dtype = self.vllm_model_config.dtype

         logger.info(
             f"Load weight end. "
@@ -290,17 +329,9 @@ class ModelRunner:
         target_device = torch.device(self.device)

         try:
-
-
-
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
@@ -409,7 +440,10 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-
+            if is_hip():  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e5m2fnuz
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
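A sketch of the platform-dependent fp8 choice above, assuming a torch build recent enough to define both float8 dtypes; `is_hip_build` stands in for sglang's `is_hip()` helper.

```python
import torch


def pick_fp8_e5m2_dtype(is_hip_build: bool) -> torch.dtype:
    # ROCm exposes the fnuz variant natively; CUDA uses the OCP e5m2 format.
    return torch.float8_e5m2fnuz if is_hip_build else torch.float8_e5m2


print(pick_fp8_e5m2_dtype(torch.version.hip is not None))
```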
@@ -548,6 +582,13 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
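A single-process sketch of the device-mesh setup that `apply_torch_tp` performs, using the gloo backend and world size 1 so it can run on one CPU; the real runner passes its CUDA device string and the actual `tp_size`, and the rendezvous address/port below are assumed values for the demo.

```python
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

# Same call shape as apply_torch_tp(): a 1-D mesh over the TP ranks.
device_mesh = torch.distributed.init_device_mesh("cpu", (1,))
print(device_mesh)

dist.destroy_process_group()
```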
@@ -573,21 +614,37 @@ class ModelRunner:
             get_embedding=True,
         )

+    def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(forward_batch)
+        elif forward_batch.forward_mode.is_idle():
+            return self.forward_idle(forward_batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")

     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.
-
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

         # Sample the next tokens.
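The two branches in `sample` distinguish overlap scheduling (the vocab mask was prepared by the previous batch's result handler, so the sampler only waits on an event) from normal mode (the mask is computed inline). A stripped-down sketch with stand-in classes, not sglang's actual `SamplingBatchInfo`:

```python
import threading


class ToySamplingInfo:
    """Stand-in keeping only the fields used by the dispatch above."""

    def __init__(self, overlap: bool, has_grammars: bool):
        self.sampling_info_done = threading.Event() if overlap else None
        self.grammars = ["<grammar>"] if has_grammars else None

    def update_regex_vocab_mask(self):
        print("computed vocab mask inline")

    def update_penalties(self):
        print("updated penalties inline")


def pre_sample(info: ToySamplingInfo):
    if info.sampling_info_done:
        # Overlap mode: another thread sets the event once the mask is ready.
        if info.grammars:
            info.sampling_info_done.wait()
    else:
        # Normal mode: do the CPU-heavy work here.
        info.update_regex_vocab_mask()
        info.update_penalties()


overlap = ToySamplingInfo(overlap=True, has_grammars=True)
overlap.sampling_info_done.set()  # pretend the previous batch finished
pre_sample(overlap)
pre_sample(ToySamplingInfo(overlap=False, has_grammars=False))
```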
@@ -613,7 +670,7 @@ class ModelRunner:

         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
-            logits =
+            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)

         return logits

@@ -637,7 +694,9 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}.
+            logger.warning(f"Ignore import error when loading {name}. {e}")
+            if crash_on_warnings():
+                raise ValueError(f"Ignore import error when loading {name}. {e}")
             continue
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
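The warning-or-crash pattern above is a common CI guard: the condition that only logs in production is escalated to an exception when a CI flag is set. A sketch using an illustrative environment variable, not sglang's actual flag name:

```python
import logging
import os

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("import_guard_demo")


def crash_on_warnings() -> bool:
    # Illustrative flag; sglang derives this from its own CI detection helper.
    return os.environ.get("DEMO_CRASH_ON_WARNINGS", "0") == "1"


def try_import(name: str):
    try:
        raise ImportError("missing optional dependency")  # simulate a bad model module
    except Exception as e:
        logger.warning(f"Ignore import error when loading {name}. {e}")
        if crash_on_warnings():
            raise ValueError(f"Ignore import error when loading {name}. {e}")


try_import("toy_model")  # warns; raises only when DEMO_CRASH_ON_WARNINGS=1
```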
sglang/srt/model_parallel.py
ADDED
@@ -0,0 +1,98 @@
+"""
+Common utilities for torch model parallelism.
+"""
+
+from typing import Optional
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+try:
+    from torch.distributed.tensor import DTensor, Shard
+except ImportError:
+    # torch 2.4 or older
+    from torch.distributed._tensor import DTensor, Shard
+
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+
+class ColwiseParallelSharded(ColwiseParallel):
+    """
+    A version of ColwiseParallel where the local weight has been already
+    sharded. This is used for the fused wqkv case, where during loading, we
+    already sharded wq, wk, wv before fusing them.
+    """
+
+    # Override the _partition_linear_fn in ColwiseParallel
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(0)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        for name, param in module.named_parameters():
+            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
+            module.register_parameter(name, dist_param)
+
+
+class RowwiseParallelMaybeWait(RowwiseParallel):
+    """
+    A version of RowwiseParallel that waits for the output (establish dependency
+    between comm stream and compute stream in CUDA sense) before going into the
+    next op. This is needed to workaround the current interaction between
+    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
+    """
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        outputs = super(
+            RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
+        )._prepare_output_fn(
+            output_layouts, use_local_output, mod, outputs, device_mesh
+        )
+        # wait for the output to be ready
+        if isinstance(outputs, AsyncCollectiveTensor):
+            return outputs.wait()
+        else:
+            return outputs
+
+
+def tensor_parallel(
+    module: torch.nn.Module,
+    device_mesh: Optional[DeviceMesh] = None,
+):
+    """
+    Tensor parallelize the model across the given device mesh.
+    Args:
+        module (`torch.nn.Module`):
+            The module to tensor parallelize.
+        device_mesh (`torch.distributed.DeviceMesh`):
+            The device mesh to use for tensor parallelism.
+    """
+
+    # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+    # No op if `_tp_plan` attribute does not exist under the module.
+    # This is a helper function to be used with `model.apply` to recursively
+    # parallelize a model.
+    def tplize(mod: torch.nn.Module) -> None:
+        tp_plan = getattr(mod, "_tp_plan", None)
+        if tp_plan is None:
+            return
+        for child_name, tp_style in tp_plan.items():
+            submod = mod.get_submodule(child_name)
+            if tp_style == "Colwise":
+                parallelize_module(submod, device_mesh, ColwiseParallel())
+            elif tp_style == "Rowwise":
+                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
+            elif tp_style == "Colwise_Sharded":
+                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
+            else:
+                raise ValueError(f"Unknown TP style {tp_style}")
+
+    # `apply` is a native method of `nn.Module` that recursively applies a
+    # function to every submodule.
+    module.apply(tplize)
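`tensor_parallel` walks the module tree and looks for a `_tp_plan` dict on each submodule; anything without the attribute is left untouched. A toy sketch of the convention (module and child names below are made up, not sglang's):

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    # Maps child-module names to a TP style string understood by tplize().
    _tp_plan = {
        "qkv_proj": "Colwise_Sharded",
        "o_proj": "Rowwise",
    }

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden, 3 * hidden, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)


# tensor_parallel(model, mesh) applies tplize() to every submodule; this loop
# only shows which plans it would find and act on.
for name, mod in ToyAttention().named_modules():
    plan = getattr(mod, "_tp_plan", None)
    if plan is not None:
        print(name or "<root>", plan)
```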
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -22,7 +22,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        use_dp=False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.num_heads = num_heads
         tp_size = get_tensor_model_parallel_world_size()
         assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

-        if
-
-
-            self.
+        if use_dp:
+            # For data parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ReplicatedLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ReplicatedLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
             )
-
-            self.
-
-            self.
+            # O projection.
+            self.o_proj = ReplicatedLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
             )
         else:
-
+            # For tensor parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ColumnParallelLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ColumnParallelLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=quant_config,
+            )
+            # O projection.
+            self.o_proj = RowParallelLinear(
+                self.num_heads * self.v_head_dim,
                 self.hidden_size,
-                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
             )
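The `num_local_heads` expression captures the key difference between the two modes: with DP attention every rank keeps all heads and works on its own slice of the batch, while with TP attention the heads themselves are split across ranks. In numbers:

```python
def num_local_heads(num_heads: int, tp_size: int, use_dp: bool) -> int:
    # Mirrors the expression in the diff; the assert matches the existing check.
    assert num_heads % tp_size == 0
    return num_heads if use_dp else num_heads // tp_size


print(num_local_heads(128, 8, use_dp=True))   # 128: DP attention keeps all heads per rank
print(num_local_heads(128, 8, use_dp=False))  # 16: TP attention shards heads across ranks
```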
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output


+def all_gather(
+    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
+):
+    if world_size == 1:
+        return input_tensor
+
+    all_lens = forward_batch.global_num_tokens
+    max_len = max(forward_batch.global_num_tokens)
+
+    padded_tensor = torch.nn.functional.pad(
+        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
+    )
+
+    torch.distributed.all_gather_into_tensor(
+        forward_batch.gathered_buffer, padded_tensor, group=group
+    )
+
+    gathered_tensors = torch.concat(
+        [
+            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
+            for i in range(world_size)
+        ]
+    )
+
+    start_index = 0 if rank == 0 else sum(all_lens[:rank])
+    end_index = start_index + all_lens[rank]
+
+    return gathered_tensors, start_index, end_index
+
+
 class DeepseekV2DecoderLayer(nn.Module):

     def __init__(
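Because each DP rank may hold a different number of tokens, `all_gather` pads every local slice to the global maximum before the collective and strips the padding back out afterwards. A local, single-process sketch of that index arithmetic (no process group involved; the lengths are made up):

```python
import torch

all_lens = [3, 5, 2]  # tokens per DP rank, i.e. forward_batch.global_num_tokens
hidden = 4
max_len = max(all_lens)
world_size = len(all_lens)

local_chunks = [torch.randn(n, hidden) for n in all_lens]
padded = [
    torch.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in local_chunks
]
# What all_gather_into_tensor would leave in forward_batch.gathered_buffer.
gathered_buffer = torch.cat(padded)

# Strip the per-rank padding, exactly like the list comprehension in the diff.
gathered = torch.cat(
    [gathered_buffer[i * max_len : i * max_len + all_lens[i]] for i in range(world_size)]
)

rank = 1
start_index = sum(all_lens[:rank])
end_index = start_index + all_lens[rank]
assert torch.equal(gathered[start_index:end_index], local_chunks[rank])
```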
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.enable_dp_attention = (
+            not global_server_args_dict["disable_mla"]
+            and global_server_args_dict["enable_dp_attention"]
+        )
+        if self.enable_dp_attention:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_group = get_tp_group().device_group
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                use_dp=self.enable_dp_attention,
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-        if
-            residual
-
-
-
-
-
-            hidden_states=
-
-
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         # Fully Connected
-
-
+        if self.enable_dp_attention:
+            hidden_states, start_idx, end_idx = all_gather(
+                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
+            )
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = hidden_states[start_idx:end_idx]
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         return hidden_states, residual


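The `is_idle()` checks exist because, with DP attention, a rank that currently has no requests still has to run the layer so it can join the all-gather issued by the busy ranks; only the per-token attention work is skipped. A toy dispatch sketch (the enum below is a stand-in for sglang's ForwardMode, not the real class):

```python
from enum import Enum, auto


class ToyForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()  # DP rank with no requests: participate in collectives only

    def is_idle(self) -> bool:
        return self is ToyForwardMode.IDLE


def decoder_layer(mode: ToyForwardMode) -> list:
    steps = []
    if not mode.is_idle():
        steps += ["layernorm", "self_attn", "post_attention_layernorm"]
    steps += ["all_gather", "mlp", "slice_local_tokens"]  # always join the collective
    return steps


print(decoder_layer(ToyForwardMode.DECODE))
print(decoder_layer(ToyForwardMode.IDLE))
```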
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = nn.ModuleList(
             [
@@ -630,7 +723,8 @@ class DeepseekV2Model(nn.Module):
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-
-
-
-
+        if global_server_args_dict["enable_dp_attention"]:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+            self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
     def forward(
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-
-
-
+        if not forward_batch.forward_mode.is_idle():
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [