sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -20,15 +20,17 @@ from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
 import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
+from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
@@ -127,7 +129,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1,
+            self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
@@ -255,7 +257,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in self.capture_bs:
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
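
Note (not part of the diff): the change above only wraps the capture loop in a progress bar on tensor-parallel rank 0, so multi-GPU runs print a single bar. A minimal standalone sketch of that pattern, with `rank` standing in for get_tensor_model_parallel_rank():

import tqdm

capture_bs = list(range(1, 33)) + [64, 128]
rank = 0  # illustration only
iterable = tqdm.tqdm(capture_bs) if rank == 0 else capture_bs
for bs in iterable:
    pass  # capture a CUDA graph for batch size `bs`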
@@ -387,8 +394,14 @@ class CudaGraphRunner:
 
         # Extract logprobs
         if forward_batch.return_logprob:
-
-
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
             )
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
@@ -396,13 +409,14 @@
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
-
-
-
-                )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                (
+                    logits_output.output_top_logprobs_val,
+                    logits_output.output_top_logprobs_idx,
+                ) = LogitsProcessor.get_top_logprobs(
                     next_token_logprobs, logits_metadata
-                )[
+                )[
+                    2:4
+                ]
         else:
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
sglang/srt/model_executor/model_runner.py CHANGED
@@ -95,6 +95,12 @@ class ModelRunner:
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
+            # FIXME(HandH1998)
+            if (
+                "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
+                and not self.server_args.disable_cuda_graph
+            ):
+                self.server_args.disable_cuda_graph = True
 
         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -111,17 +117,20 @@ class ModelRunner:
             )
 
         if self.is_multimodal:
-            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
-
-
-
-
-
+            if self.model_config.hf_config.architectures == [
+                "MllamaForConditionalGeneration"
+            ]:
+                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                )
+                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True
 
         # Global vars
@@ -154,6 +163,11 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        # Apply torchao quantization
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -162,10 +176,6 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
-
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
@@ -242,20 +252,22 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
-        # Prepare the
+        # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
+
+        # Load the model
         self.model = get_model(
             model_config=self.model_config,
             load_config=self.load_config,
             device_config=DeviceConfig(self.device),
         )
 
+        # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +282,10 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-    def update_weights_from_disk(
-
+    def update_weights_from_disk(
+        self, model_path: str, load_format: str
+    ) -> tuple[bool, str]:
+        """Update engine weights in-place from the disk."""
         from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
sglang/srt/model_parallel.py CHANGED
@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """
 
-from typing import Optional
+from typing import Optional, Sequence
 
 import torch
+import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 
 try:
-
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-
+    import torch.distributed._tensor as dt
 
-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )
 
 
+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
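
Note (not part of the diff): `_shard_tensor` above avoids the collective that `distribute_tensor` would issue by slicing the already-replicated full tensor locally and wrapping the slice via `DTensor.from_local`. A rough standalone sketch of the slicing idea for an even 1-D split (the real helper uses `compute_local_shape_and_global_offset`, which also handles uneven and multi-dim cases); `local_rowwise_slice` is a hypothetical illustration, not part of sglang:

import torch

def local_rowwise_slice(full: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
    # Even split assumed for simplicity.
    size = full.size(dim) // world_size
    return full.narrow(dim, rank * size, size)

weight = torch.arange(32.0).reshape(4, 8)
print(local_rowwise_slice(weight, rank=1, world_size=4, dim=1))  # columns 2..3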
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)
 
@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """
 
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
sglang/srt/models/dbrx.py CHANGED
@@ -27,13 +27,13 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek_v2.py CHANGED
@@ -19,6 +19,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,7 +40,13 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    input_to_float8,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -90,6 +95,24 @@ class DeepseekV2MLP(nn.Module):
         return x
 
 
+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        if config.topk_method == "noaux_tc":
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((config.n_routed_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
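
Note (not part of the diff): the new MoEGate is a bias-free linear that maps each token's hidden state to one router logit per routed expert; the actual routing (grouped top-k, correction bias) happens inside the FusedMoE/EPMoE layer. A minimal shape sketch with made-up sizes:

import torch
import torch.nn.functional as F

n_routed_experts, hidden_size, num_tokens = 8, 16, 4
weight = torch.randn(n_routed_experts, hidden_size)     # MoEGate.weight
hidden_states = torch.randn(num_tokens, hidden_size)

router_logits = F.linear(hidden_states, weight, None)   # (num_tokens, n_routed_experts)
topk_vals, topk_ids = router_logits.topk(2, dim=-1)     # e.g. pick 2 experts per token
print(router_logits.shape, topk_ids.shape)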
@@ -114,6 +137,8 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
+        self.gate = MoEGate(config=config)
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
@@ -125,11 +150,9 @@ class DeepseekV2MoE(nn.Module):
             use_grouped_topk=True,
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
         )
 
-        self.gate = ReplicatedLinear(
-            config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-        )
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(
@@ -146,7 +169,7 @@ class DeepseekV2MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits
+        router_logits = self.gate(hidden_states)
         final_hidden_states = (
             self.experts(hidden_states=hidden_states, router_logits=router_logits)
             * self.routed_scaling_factor
@@ -167,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-def input_to_float8(x, dtype=torch.float8_e4m3fn):
-    finfo = torch.finfo(dtype)
-    min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
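
Note (not part of the diff): `input_to_float8` is only relocated by this hunk; 0.4.1 imports it from sglang.srt.layers.quantization.fp8_utils (see the import hunk above) instead of defining it locally. A standalone round-trip sketch of the per-tensor scaling it performs, assuming a torch build with float8_e4m3fn support:

import torch

def input_to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

x = torch.randn(4, 8)
x_fp8, inv_scale = input_to_float8(x)
x_back = x_fp8.to(torch.float32) * inv_scale
print((x - x_back).abs().max())  # small per-tensor quantization error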
@@ -439,7 +453,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
|
|
454
471
|
scaling_factor = rope_scaling["factor"]
|
455
472
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
456
473
|
self.scaling = self.scaling * mscale * mscale
|
474
|
+
else:
|
475
|
+
self.rotary_emb.forward = self.rotary_emb.forward_native
|
457
476
|
|
458
477
|
self.attn_mqa = RadixAttention(
|
459
478
|
self.num_local_heads,
|
@@ -845,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            # TODO(HandH1998): Modify it when nextn is supported.
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                if num_nextn_layers > 0 and name.startswith("model.layers"):
+                    name_list = name.split(".")
+                    if (
+                        len(name_list) >= 3
+                        and int(name_list[2]) >= self.config.num_hidden_layers
+                    ):
+                        continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
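
Note (not part of the diff): the filter added above skips checkpoint weights belonging to the extra "nextn" predict layers, i.e. any "model.layers.&lt;idx&gt;..." entry whose layer index is >= num_hidden_layers. A tiny standalone sketch with made-up values (the weight name here is hypothetical):

num_hidden_layers = 61
name = "model.layers.61.shared_head.head.weight"  # hypothetical nextn weight name
name_list = name.split(".")
skip = len(name_list) >= 3 and int(name_list[2]) >= num_hidden_layers
print(skip)  # True -> this weight would be skipped during loading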
@@ -909,13 +938,33 @@
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and w.dtype == torch.float8_e4m3fn
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        w, scale = block_quant_to_tensor_quant(
+                            w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                        )
+                        self_attn.w_scale = scale
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
             self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if
+            if (
+                hasattr(self_attn.kv_b_proj, "weight_scale")
+                and self_attn.w_scale is None
+            ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
 
 
-
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
sglang/srt/models/gemma2.py CHANGED
@@ -355,6 +355,40 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
 
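
Note (not part of the diff): get_module_name maps per-head HF parameter names onto the fused SGLang projections (q/k/v -> qkv_proj, gate/up -> gate_up_proj), and get_hidden_dim reports the input/output width of each fused module, presumably for adapter-style utilities that need per-module sizes. A standalone sketch of the mapping logic:

params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

def get_module_name(name: str) -> str:
    return params_mapping.get(name, name)

print(get_module_name("k_proj"))  # qkv_proj
print(get_module_name("o_proj"))  # o_proj (unchanged)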
sglang/srt/models/gemma2_reward.py CHANGED
@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = Gemma2Model(config, quant_config=quant_config)