sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py ADDED
@@ -0,0 +1,256 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""ModelRunner runs the forward passes of the models."""
+from dataclasses import dataclass
+from enum import IntEnum, auto
+from typing import List
+
+import numpy as np
+import torch
+
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+
+class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
+    PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    EXTEND = auto()
+    # Decode one token.
+    DECODE = auto()
+
+
+@dataclass
+class InputMetadata:
+    """Store all information of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: BaseTokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Triton attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    flashinfer_use_ragged: bool = False
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        flashinfer_use_ragged = False
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                flashinfer_use_ragged = True
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+                flashinfer_use_ragged,
+            )
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatible with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+    flashinfer_use_ragged=False,
+):
+    """Init auxiliary variables for FlashInfer attention backend."""
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
+
+    if flashinfer_use_ragged:
+        paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        if flashinfer_use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
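
The new forward_batch_info module centralizes everything a model forward pass needs to know about the current batch. As the model_runner.py hunks below show, ModelRunner builds this object through InputMetadata.create; the sketch below mirrors that call pattern with placeholder tensors (the two-request batch and all argument values are illustrative, not taken from the diff):

    # Sketch of the InputMetadata.create call pattern visible in model_runner.py below.
    # `model_runner` is assumed to be an already-initialized sglang ModelRunner.
    import torch
    from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

    input_metadata = InputMetadata.create(
        model_runner,
        forward_mode=ForwardMode.EXTEND,
        req_pool_indices=torch.tensor([0, 1], device="cuda"),
        seq_lens=torch.tensor([12, 7], device="cuda"),        # total tokens per request
        prefix_lens=torch.tensor([4, 0], device="cuda"),      # tokens already in the KV cache
        position_ids_offsets=torch.zeros(2, dtype=torch.int64, device="cuda"),
        out_cache_loc=torch.arange(15, device="cuda"),        # 15 = sum(seq_lens - prefix_lens), placeholder slots
        return_logprob=False,
    )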
sglang/srt/model_executor/model_runner.py CHANGED
@@ -41,13 +41,14 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +87,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )
 
@@ -193,15 +195,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = (
-            head_num
-            * head_dim
-            * self.model_config.num_hidden_layers
-            * 2
-            * torch._utils._element_size(self.dtype)
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -241,13 +251,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=self.dtype,
-            head_num=self.model_config.get_num_kv_heads(self.tp_size),
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementation, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
@@ -321,7 +346,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_decode(self, batch: ScheduleBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
@@ -341,7 +366,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -358,7 +383,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: Batch):
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -379,7 +404,7 @@ class ModelRunner:
             batch.image_offsets,
         )
 
-    def forward(self, batch: Batch, forward_mode: ForwardMode):
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
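
The memory-pool hunks above are the core of the MLA support: with enable_mla set and an MLA-capable model, each token stores one compressed latent of width kv_lora_rank + qk_rope_head_dim per layer, instead of 2 * num_kv_heads * head_dim values in a standard MHA cache. A back-of-the-envelope comparison of the two cell_size formulas, with illustrative, roughly DeepSeek-V2-shaped numbers that are not read from the diff:

    # Per-token KV-cache bytes under the two branches added in init_memory_pool above.
    # All dimensions are assumptions for illustration; dtype is fp16 (2 bytes per element).
    num_layers, elem_size = 60, 2
    kv_lora_rank, qk_rope_head_dim = 512, 64      # MLA latent widths
    num_kv_heads, head_dim = 128, 192             # hypothetical MHA-equivalent shapes

    mla_cell = (kv_lora_rank + qk_rope_head_dim) * num_layers * elem_size
    mha_cell = num_kv_heads * head_dim * num_layers * 2 * elem_size
    print(f"MLA: {mla_cell} B/token, MHA: {mha_cell} B/token, ratio ~{mha_cell / mla_cell:.0f}x")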
sglang/srt/models/chatglm.py CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
 
sglang/srt/models/commandr.py CHANGED
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 @torch.compile
sglang/srt/models/dbrx.py CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DbrxRouter(nn.Module):
sglang/srt/models/deepseek.py CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DeepseekMLP(nn.Module):
sglang/srt/models/deepseek_v2.py CHANGED
@@ -45,7 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DeepseekV2MLP(nn.Module):
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
         return output
 
 
+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
326
486
  rope_theta = getattr(config, "rope_theta", 10000)
327
487
  rope_scaling = getattr(config, "rope_scaling", None)
328
488
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
329
- self.self_attn = DeepseekV2Attention(
330
- config=config,
331
- hidden_size=self.hidden_size,
332
- num_heads=config.num_attention_heads,
333
- qk_nope_head_dim=config.qk_nope_head_dim,
334
- qk_rope_head_dim=config.qk_rope_head_dim,
335
- v_head_dim=config.v_head_dim,
336
- q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
337
- kv_lora_rank=config.kv_lora_rank,
338
- rope_theta=rope_theta,
339
- rope_scaling=rope_scaling,
340
- max_position_embeddings=max_position_embeddings,
341
- cache_config=cache_config,
342
- quant_config=quant_config,
343
- layer_id=layer_id,
344
- )
489
+ if global_server_args_dict["enable_mla"]:
490
+ self.self_attn = DeepseekV2AttentionMLA(
491
+ config=config,
492
+ hidden_size=self.hidden_size,
493
+ num_heads=config.num_attention_heads,
494
+ qk_nope_head_dim=config.qk_nope_head_dim,
495
+ qk_rope_head_dim=config.qk_rope_head_dim,
496
+ v_head_dim=config.v_head_dim,
497
+ q_lora_rank=(
498
+ config.q_lora_rank if hasattr(config, "q_lora_rank") else None
499
+ ),
500
+ kv_lora_rank=config.kv_lora_rank,
501
+ rope_theta=rope_theta,
502
+ rope_scaling=rope_scaling,
503
+ max_position_embeddings=max_position_embeddings,
504
+ cache_config=cache_config,
505
+ quant_config=quant_config,
506
+ layer_id=layer_id,
507
+ )
508
+ else:
509
+ self.self_attn = DeepseekV2Attention(
510
+ config=config,
511
+ hidden_size=self.hidden_size,
512
+ num_heads=config.num_attention_heads,
513
+ qk_nope_head_dim=config.qk_nope_head_dim,
514
+ qk_rope_head_dim=config.qk_rope_head_dim,
515
+ v_head_dim=config.v_head_dim,
516
+ q_lora_rank=(
517
+ config.q_lora_rank if hasattr(config, "q_lora_rank") else None
518
+ ),
519
+ kv_lora_rank=config.kv_lora_rank,
520
+ rope_theta=rope_theta,
521
+ rope_scaling=rope_scaling,
522
+ max_position_embeddings=max_position_embeddings,
523
+ cache_config=cache_config,
524
+ quant_config=quant_config,
525
+ layer_id=layer_id,
526
+ )
345
527
  if (
346
528
  config.n_routed_experts is not None
347
529
  and layer_id >= config.first_k_dense_replace
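
Which attention class the decoder layer builds is decided at load time by global_server_args_dict["enable_mla"], which ModelRunner populates from the new enable_mla server argument (and which in turn forces disable_flashinfer, since only the Triton MLA kernels are supported for now). A minimal usage sketch, assuming the Runtime wrapper forwards keyword arguments to ServerArgs as in earlier sglang releases; the model id is only an example:

    import sglang as sgl

    # Assumption: enable_mla maps onto the ServerArgs.enable_mla field added in this release.
    runtime = sgl.Runtime(
        model_path="deepseek-ai/DeepSeek-V2-Lite",  # example checkpoint with MLA attention
        enable_mla=True,
        trust_remote_code=True,
    )
    # ... issue requests against the runtime's endpoint ...
    runtime.shutdown()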
sglang/srt/models/gemma.py CHANGED
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GemmaMLP(nn.Module):
sglang/srt/models/gemma2.py CHANGED
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GemmaRMSNorm(CustomOp):