sglang 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +114 -63
- sglang/check_env.py +2 -0
- sglang/lang/backend/runtime_endpoint.py +0 -11
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/radix_attention.py +22 -9
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +15 -11
- sglang/srt/managers/tokenizer_manager.py +28 -13
- sglang/srt/mem_cache/memory_pool.py +65 -24
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/model_runner.py +52 -21
- sglang/srt/models/deepseek_v2.py +198 -16
- sglang/srt/openai_api/adapter.py +120 -20
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/server.py +87 -78
- sglang/srt/server_args.py +8 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +94 -13
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/METADATA +29 -27
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/RECORD +33 -30
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -92,7 +92,7 @@ class GenerateReqInput:
                 for element in parallel_sample_num_list
             )
             if parallel_sample_num > 1 and (not all_equal):
-
+                # TODO cope with the case that the parallel_sample_num is different for different samples
                 raise ValueError(
                     "The parallel_sample_num should be the same for all samples in sample params."
                 )
@@ -103,14 +103,19 @@ class GenerateReqInput:
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
             num = parallel_sample_num + 1
-            if isinstance(self.text,
-
+            if isinstance(self.text, list):
+                # suppot batch operation
                 self.batch_size = len(self.text)
                 num = num * len(self.text)
+            elif isinstance(self.input_ids, list) and isinstance(
+                self.input_ids[0], list
+            ):
+                self.batch_size = len(self.input_ids)
+                num = num * len(self.input_ids)
             else:
                 self.batch_size = 1
         else:
-
+            # support select operation
             num = len(self.text) if self.text is not None else len(self.input_ids)
             self.batch_size = num
 
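The second hunk extends parallel sampling to batched inputs: each prompt gets one extra prefill request on top of its n sampled generations, whether the batch arrives as a list of strings or as a list of token-id lists. A minimal sketch of that bookkeeping, with illustrative prompts and n that are not taken from the diff:

# Sketch of the request-count bookkeeping for batched parallel sampling.
# The prompts and parallel_sample_num below are illustrative, not from the diff.
prompts = ["a", "b", "c"]          # batch of 3 prompts
parallel_sample_num = 2            # sampling_params["n"]

if parallel_sample_num != 1:
    num = parallel_sample_num + 1  # +1 for the shared prefill request per prompt
    batch_size = len(prompts)
    num = num * len(prompts)       # 3 prompts * (2 samples + 1 prefill) = 9 requests
else:
    num = len(prompts)
    batch_size = num

print(batch_size, num)             # 3 9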
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -29,7 +29,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@ global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }
 
 
@@ -289,7 +290,7 @@ class Batch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool:
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache
 
     # Batched arguments to model runner
@@ -380,13 +381,15 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache
-
+            if self.tree_cache is not None:
+                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
             if out_cache_loc is None:
-                logger.error("Prefill out of memory.
-                self.tree_cache
-
+                logger.error("Prefill out of memory. Try to lower your batch size.")
+                if self.tree_cache is not None:
+                    self.tree_cache.pretty_print()
+                exit(1)
 
         pt = 0
         for i in range(bs):
@@ -637,9 +640,10 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
         if self.out_cache_loc is None:
-            logger.error("Decode out of memory.
-            self.tree_cache
-
+            logger.error("Decode out of memory. Try to lower your batch size.")
+            if self.tree_cache is not None:
+                self.tree_cache.pretty_print()
+            exit(1)
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
@@ -777,7 +781,7 @@ class InputMetadata:
     seq_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool:
+    token_to_kv_pool: BaseTokenToKVPool
 
     # For extend
     extend_seq_lens: torch.Tensor
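The two out-of-memory hunks in Batch follow the same evict-then-retry pattern: allocate from the token-to-KV pool, evict entries from the radix tree cache if the pool is exhausted, retry once, and abort with a diagnostic if that still fails. A standalone sketch of the pattern; Pool and Cache here are simplified stand-ins, not the sglang classes:

# Standalone sketch of the evict-then-retry allocation pattern used in Batch.
# Pool and Cache are simplified stand-ins for token_to_kv_pool / tree_cache.
class Pool:
    def __init__(self, size):
        self.free_slots = list(range(size))

    def alloc(self, n):
        if len(self.free_slots) < n:
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots):
        self.free_slots.extend(slots)


class Cache:
    def __init__(self, cached_slots):
        self.cached_slots = cached_slots

    def evict(self, n, free_fn):
        free_fn(self.cached_slots[:n])
        del self.cached_slots[:n]


def alloc_with_eviction(pool, cache, n):
    out = pool.alloc(n)
    if out is None and cache is not None:
        cache.evict(n, pool.free)      # give cached slots back to the pool
        out = pool.alloc(n)            # retry once after eviction
    if out is None:
        raise MemoryError("out of KV cache memory; lower the batch size")
    return out


pool, cache = Pool(4), Cache(cached_slots=[100, 101, 102, 103])
pool.alloc(3)                               # leaves one free slot
print(alloc_with_eviction(pool, cache, 3))  # eviction frees enough to satisfy the retry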
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -153,8 +153,9 @@ class TokenizerManager:
     async def _handle_single_request(
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
-        if not is_cache_for_prefill:
-            not_use_index =
+        if not is_cache_for_prefill:  # The normal case with a single prompt
+            not_use_index = index is None
+
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
@@ -182,14 +183,27 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
-        else:
-            if
-
-
+        else:  # A prefill request to cache the common prompt for parallel sampling
+            if obj.text is not None:
+                if isinstance(obj.text, list):
+                    input_text = obj.text[index]
+                    rid = obj.rid[index]
+                else:
+                    input_text = obj.text
+                    rid = obj.rid[0]
+                input_ids = self.tokenizer.encode(input_text)
             else:
-                input_text =
-
-
+                input_text = None
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj["input_ids"] is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
+
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
             pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -240,11 +254,11 @@ class TokenizerManager:
         ):
             if input_id_result is not None:
                 input_id_result.append(input_id)
-
-        if len(input_id_result) > 1 and input_id_result is not None:
+        if input_id_result is not None and len(input_id_result) > 1:
             obj.input_ids = input_id_result
         elif input_id_result is not None:
             obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -264,11 +278,12 @@ class TokenizerManager:
                         input_text = None
                         input_ids = obj.input_ids[i]
                 else:
+                    assert obj.input_ids is not None
                     if batch_size == 1:
-                        input_text =
+                        input_text = None
                         input_ids = obj.input_ids
                     else:
-                        input_text =
+                        input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -57,32 +57,18 @@ class ReqToTokenPool:
         self.can_use_mem_size = len(self.mem_state)
 
 
-class
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
@@ -90,15 +76,6 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()
 
-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
 
@@ -139,3 +116,67 @@ class TokenToKVPool:
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
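The MLA pool stores a single compressed vector of width kv_lora_rank + qk_rope_head_dim per token and layer, and get_value_buffer returns a slice of the same storage rather than a second buffer. A CPU-only sketch of that layout with small illustrative dimensions (the real pool allocates on CUDA):

# CPU sketch of the MLA pool layout: one [slots, 1, kv_lora_rank + rope_dim] buffer
# per layer, with the "value" view being the first kv_lora_rank channels of it.
# Shapes below are small illustrative values, not the real pool sizes.
import torch

size, kv_lora_rank, qk_rope_head_dim, layer_num = 8, 4, 2, 3
kv_buffer = [
    torch.zeros((size + 1, 1, kv_lora_rank + qk_rope_head_dim))
    for _ in range(layer_num)
]

key_view = kv_buffer[0]                        # full compressed KV + rope part
value_view = kv_buffer[0][..., :kv_lora_rank]  # first kv_lora_rank channels only

value_view[3] = 1.0                            # writing through the value view ...
print(key_view[3, 0])                          # ... shows up in the key buffer (shared storage)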
sglang/srt/model_config.py
CHANGED
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from enum import IntEnum, auto
 from typing import Optional
 
 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ class ModelConfig:
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA
 
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
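The ModelConfig hunk only records which attention flavor the checkpoint uses and copies the two MLA dimensions out of the HF config. A minimal standalone sketch of that detection, where SimpleNamespace and the numeric values stand in for a real PretrainedConfig:

# Sketch of the architecture detection added above; the config object and its
# values are illustrative stand-ins, not a real HF PretrainedConfig.
from enum import IntEnum, auto
from types import SimpleNamespace


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


def detect_attention_arch(hf_config):
    if "DeepseekV2ForCausalLM" in hf_config.architectures:
        return AttentionArch.MLA, hf_config.kv_lora_rank, hf_config.qk_rope_head_dim
    return AttentionArch.MHA, None, None


cfg = SimpleNamespace(
    architectures=["DeepseekV2ForCausalLM"], kv_lora_rank=512, qk_rope_head_dim=64
)
print(detect_attention_arch(cfg))  # (<AttentionArch.MLA: 1>, 512, 64)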
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +91,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )
 
@@ -193,15 +199,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-
-
-
-
-
-
-
-
-
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -241,13 +255,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-
-        self.
-
-
-
-
-
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementaion, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
@@ -312,10 +341,12 @@ class ModelRunner:
             self.cuda_graph_runner.capture(batch_size_list)
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed: {e}
-
-
-
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable torch compile by not using --enable-torch-compile\n"
+                "2. disable cuda graph by --disable-cuda-graph\n"
+                "3. set --mem-fraction-static to a smaller value\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
     @torch.inference_mode()
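The cell_size branch above is what drives the KV-pool capacity estimate: the MLA path prices one compressed vector per token and layer, while the MHA path prices separate K and V tensors per KV head. A back-of-the-envelope comparison; all model numbers below are illustrative DeepSeek-V2-like assumptions, not values read from the diff:

# Rough per-token KV cache cell sizes behind the branch above.
# All configuration numbers are illustrative assumptions.
elem_size = 2                                        # bytes per element for bf16/fp16
layer_num = 60

# MHA path: kv_heads * head_dim * layers * 2 (K and V) * element size
mha_cell = 128 * 256 * layer_num * 2 * elem_size     # ~7.5 MiB per token

# MLA path: (kv_lora_rank + qk_rope_head_dim) * layers * element size
mla_cell = (512 + 64) * layer_num * elem_size        # ~67.5 KiB per token

free_gib = 60                                        # hypothetical memory left for the KV pool
for name, cell in [("MHA", mha_cell), ("MLA", mla_cell)]:
    print(name, cell, "bytes/token ->", int(free_gib * 2**30 // cell), "tokens")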
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.model_runner import InputMetadata
 
 
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
         return output
 
 
+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = DeepseekV2AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = DeepseekV2Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
         if (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace