sglang 0.2.9.post1__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,7 +29,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@ global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }

@@ -289,7 +290,7 @@ class Batch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache

     # Batched arguments to model runner
@@ -380,13 +381,15 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            if self.tree_cache is not None:
+                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
-                logger.error("Prefill out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
+                logger.error("Prefill out of memory. Try to lower your batch size.")
+                if self.tree_cache is not None:
+                    self.tree_cache.pretty_print()
+                exit(1)

         pt = 0
         for i in range(bs):
@@ -637,9 +640,10 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

         if self.out_cache_loc is None:
-            logger.error("Decode out of memory. This should never happen.")
-            self.tree_cache.pretty_print()
-            exit()
+            logger.error("Decode out of memory. Try to lower your batch size.")
+            if self.tree_cache is not None:
+                self.tree_cache.pretty_print()
+            exit(1)

         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
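
The two hunks above change how the prefill and decode paths in `Batch` handle running out of KV-cache slots: eviction from the tree cache is only attempted when a cache is actually configured (it can now be `None`), the error message suggests lowering the batch size instead of claiming the condition is impossible, and the process exits with status 1. A minimal sketch of the allocate-evict-retry pattern, using toy stand-ins rather than the real sglang pool and cache classes:

```python
from typing import Optional


class ToyKVPool:
    """Hypothetical stand-in for a token-to-KV pool with a fixed capacity."""

    def __init__(self, size: int):
        self.free_slots = size

    def alloc(self, need: int) -> Optional[int]:
        if need > self.free_slots:
            return None  # like the real pool, signal failure instead of raising
        self.free_slots -= need
        return need

    def free(self, n: int):
        self.free_slots += n


class ToyTreeCache:
    """Hypothetical stand-in for the radix tree cache."""

    def __init__(self, cached_tokens: int):
        self.cached_tokens = cached_tokens

    def evict(self, need: int, free_fn):
        evicted = min(need, self.cached_tokens)
        self.cached_tokens -= evicted
        free_fn(evicted)  # return evicted slots to the pool


def alloc_with_eviction(pool: ToyKVPool, tree_cache: Optional[ToyTreeCache], need: int):
    out = pool.alloc(need)
    if out is None:
        # Same shape as the new code: only evict if a cache exists, then retry once.
        if tree_cache is not None:
            tree_cache.evict(need, pool.free)
            out = pool.alloc(need)
        if out is None:
            # The real code logs the error and calls exit(1); a toy just raises.
            raise MemoryError("Prefill out of memory. Try to lower your batch size.")
    return out


pool = ToyKVPool(size=8)
cache = ToyTreeCache(cached_tokens=16)
pool.alloc(8)                                # pool exhausted
print(alloc_with_eviction(pool, cache, 4))   # eviction frees slots, retry succeeds
```
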
@@ -777,7 +781,7 @@ class InputMetadata:
     seq_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool

     # For extend
     extend_seq_lens: torch.Tensor
@@ -153,8 +153,9 @@ class TokenizerManager:
     async def _handle_single_request(
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
-        if not is_cache_for_prefill:
-            not_use_index = not (index is not None)
+        if not is_cache_for_prefill:  # The normal case with a single prompt
+            not_use_index = index is None
+
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
@@ -182,14 +183,27 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
-        else:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
+        else:  # A prefill request to cache the common prompt for parallel sampling
+            if obj.text is not None:
+                if isinstance(obj.text, list):
+                    input_text = obj.text[index]
+                    rid = obj.rid[index]
+                else:
+                    input_text = obj.text
+                    rid = obj.rid[0]
+                input_ids = self.tokenizer.encode(input_text)
             else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
+                input_text = None
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj["input_ids"] is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
+
         sampling_params = SamplingParams(**obj.sampling_params[0])
         sampling_params.max_new_tokens = 0
         pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -240,11 +254,11 @@ class TokenizerManager:
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
-                    pass
-            if len(input_id_result) > 1 and input_id_result is not None:
+            if input_id_result is not None and len(input_id_result) > 1:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -264,11 +278,12 @@ class TokenizerManager:
                         input_text = None
                         input_ids = obj.input_ids[i]
                 else:
+                    assert obj.input_ids is not None
                     if batch_size == 1:
-                        input_text = obj.text
+                        input_text = None
                         input_ids = obj.input_ids
                     else:
-                        input_text = obj.text[i]
+                        input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
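
These TokenizerManager hunks make parallel sampling work when the client sends pre-tokenized `input_ids` instead of `text`: the cache-priming prefill request (the `is_cache_for_prefill` branch, sent with `sampling_params.max_new_tokens = 0` so it only fills the prefix cache) now takes its prompt from either field, and the per-sample requests always pass `input_text = None` together with the token IDs. A sketch of the new prompt selection, with a plain dict standing in for the request object `obj`:

```python
def fake_tokenize(text: str):
    # Hypothetical tokenizer stand-in for self.tokenizer.encode.
    return list(range(len(text.split())))


def select_prompt_for_prefill(obj: dict, index: int):
    """Mirror the new `is_cache_for_prefill` branch: return (rid, input_text, input_ids)."""
    if obj.get("text") is not None:
        if isinstance(obj["text"], list):
            input_text, rid = obj["text"][index], obj["rid"][index]
        else:
            input_text, rid = obj["text"], obj["rid"][0]
        input_ids = fake_tokenize(input_text)
    else:
        input_text = None
        if isinstance(obj["input_ids"][0], list):  # List[List[int]]: one prompt per batch item
            input_ids, rid = obj["input_ids"][index], obj["rid"][index]
        else:  # a single flat prompt shared by all samples
            input_ids, rid = obj["input_ids"], obj["rid"][0]
    return rid, input_text, input_ids


# A pre-tokenized batch of two prompts, each sampled in parallel:
req = {"rid": ["r0", "r1"], "text": None, "input_ids": [[1, 2, 3], [4, 5]]}
print(select_prompt_for_prefill(req, index=1))  # ('r1', None, [4, 5])
```
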
@@ -57,32 +57,18 @@ class ReqToTokenPool:
         self.can_use_mem_size = len(self.mem_state)


-class TokenToKVPool:
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
@@ -90,15 +76,6 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)

@@ -139,3 +116,67 @@ class TokenToKVPool:

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
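
The memory-pool refactor splits the old `TokenToKVPool` into a `BaseTokenToKVPool` that only tracks slot allocation (free-slot state and the prefetch buffer), plus two concrete layouts: `MHATokenToKVPool` keeps separate `k_buffer`/`v_buffer` tensors of shape `[size + 1, head_num, head_dim]` per layer, while `MLATokenToKVPool` keeps a single compressed buffer of shape `[size + 1, 1, kv_lora_rank + qk_rope_head_dim]` per layer and serves the "value" view as a slice of the first `kv_lora_rank` channels of the same storage. A CPU-only sketch of the two layouts, with DeepSeek-V2-like dimensions assumed for illustration (the real pools allocate on CUDA):

```python
import torch

size, layer_num = 4, 2

# MHA layout: separate K and V per layer.
head_num, head_dim = 8, 128
k_buffer = [torch.empty(size + 1, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.empty(size + 1, head_num, head_dim) for _ in range(layer_num)]

# MLA layout: one compressed latent + rope buffer per layer, shared by K and V.
kv_lora_rank, qk_rope_head_dim = 512, 64
kv_buffer = [
    torch.empty(size + 1, 1, kv_lora_rank + qk_rope_head_dim) for _ in range(layer_num)
]

key_view = kv_buffer[0]                         # what get_key_buffer(0) returns
value_view = kv_buffer[0][..., :kv_lora_rank]   # what get_value_buffer(0) returns
assert value_view.data_ptr() == key_view.data_ptr()  # same storage, no copy

print(k_buffer[0].shape, v_buffer[0].shape)  # torch.Size([5, 8, 128]) twice
print(key_view.shape, value_view.shape)      # torch.Size([5, 1, 576]) / torch.Size([5, 1, 512])
```
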
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from enum import IntEnum, auto
 from typing import Optional

 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length


+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ class ModelConfig:
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA

         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
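
With the new `AttentionArch` enum, `ModelConfig` records which attention layout a checkpoint uses; only `DeepseekV2ForCausalLM` is mapped to MLA, and the MLA dimensions are read straight off the Hugging Face config. A condensed sketch of the detection logic, with a hypothetical namespace standing in for `hf_config`:

```python
from enum import IntEnum, auto
from types import SimpleNamespace


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


def detect_attention_arch(hf_config):
    # Same special case as the diff: only DeepseekV2ForCausalLM gets MLA.
    if "DeepseekV2ForCausalLM" in hf_config.architectures:
        return AttentionArch.MLA, hf_config.kv_lora_rank, hf_config.qk_rope_head_dim
    return AttentionArch.MHA, None, None


# Hypothetical config values in the shape of DeepSeek-V2's hf_config.
cfg = SimpleNamespace(
    architectures=["DeepseekV2ForCausalLM"], kv_lora_rank=512, qk_rope_head_dim=64
)
print(detect_attention_arch(cfg))  # (<AttentionArch.MLA: 1>, 512, 64)
```
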
@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +91,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )

@@ -193,15 +199,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = (
-            head_num
-            * head_dim
-            * self.model_config.num_hidden_layers
-            * 2
-            * torch._utils._element_size(self.dtype)
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
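
The profiling change above is where the memory savings show up: per token and per layer, the MHA layout stores `2 * head_num * head_dim` elements, while the MLA layout stores only `kv_lora_rank + qk_rope_head_dim` elements shared across all heads. A rough back-of-the-envelope comparison, assuming DeepSeek-V2-like values (60 layers, 128 KV heads, head_dim forced to 256 as in the ModelConfig hunk above, kv_lora_rank 512, qk_rope_head_dim 64, a 2-byte dtype, tp_size 1):

```python
# Per-token KV cache cell size in bytes, mirroring the two branches above.
# All model dimensions below are assumed DeepSeek-V2-like numbers, not read from a config.
num_layers, elem_size = 60, 2

mha_cell = 128 * 256 * num_layers * 2 * elem_size  # K and V, all heads materialized
mla_cell = (512 + 64) * num_layers * elem_size     # one compressed latent + rope part

print(f"MHA: {mha_cell / 1e6:.2f} MB/token, MLA: {mla_cell / 1e6:.3f} MB/token")
print(f"reduction: ~{mha_cell / mla_cell:.0f}x")  # roughly two orders of magnitude
```
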
@@ -241,13 +255,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=self.dtype,
-            head_num=self.model_config.get_num_kv_heads(self.tp_size),
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementaion, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.model_runner import InputMetadata

@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
         return output


+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):

     def __init__(
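
`DeepseekV2AttentionMLA` is the core of the change: instead of materializing per-head keys and values, it caches only the compressed latent plus the rotary part (`kv_lora_rank + qk_rope_head_dim` channels, a single KV head) and absorbs the key up-projection `w_kc` into the query and the value up-projection `w_vc` into the attention output via the two `torch.bmm` calls. The toy check below (CPU, no RoPE channels, no softmax scaling, random attention weights in place of the real kernel) verifies the algebra behind that absorption; the weight layout follows the `unflatten`/`split` in the diff, everything else is illustrative:

```python
import torch

torch.manual_seed(0)

num_heads, qk_nope, v_dim, kv_rank = 4, 16, 16, 32
q_len, kv_len = 3, 5

# Compressed KV latents as stored by MLATokenToKVPool (one "head" per token).
c = torch.randn(kv_len, kv_rank)

# Up-projection weights in the layout produced by unflatten/split:
# w_kc: [heads, qk_nope, kv_rank], w_vc: [heads, v_dim, kv_rank]
w_kc = torch.randn(num_heads, qk_nope, kv_rank)
w_vc = torch.randn(num_heads, v_dim, kv_rank)

q_nope = torch.randn(q_len, num_heads, qk_nope)
probs = torch.softmax(torch.randn(num_heads, q_len, kv_len), dim=-1)  # stand-in attention weights

# --- Materialized path: expand per-head K/V, then attend ---------------------
k = torch.einsum("hdr,sr->shd", w_kc, c)             # [kv_len, heads, qk_nope]
v = torch.einsum("hvr,sr->shv", w_vc, c)             # [kv_len, heads, v_dim]
scores_ref = torch.einsum("thd,shd->hts", q_nope, k)
out_ref = torch.einsum("hts,shv->thv", probs, v)

# --- Absorbed path (what DeepseekV2AttentionMLA.forward does) ----------------
q_absorbed = torch.bmm(q_nope.transpose(0, 1), w_kc)           # [heads, q_len, kv_rank]
scores_abs = torch.einsum("htr,sr->hts", q_absorbed, c)        # attend directly to latents
o_compressed = torch.einsum("hts,sr->htr", probs, c)           # output still in latent space
out_abs = torch.bmm(o_compressed, w_vc.transpose(1, 2)).transpose(0, 1)

assert torch.allclose(scores_ref, scores_abs, atol=1e-5)
assert torch.allclose(out_ref, out_abs, atol=1e-5)
print("absorbed path matches materialized attention")
```
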
@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.self_attn = DeepseekV2Attention(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            layer_id=layer_id,
-        )
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = DeepseekV2AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = DeepseekV2Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
         if (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
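
Finally, the decoder layer switches between the two attention implementations based on `global_server_args_dict["enable_mla"]`, which `ModelRunner` fills from `server_args.enable_mla`, so the MLA path is strictly opt-in and only reachable for DeepSeek-V2 checkpoints. A toy sketch of that flag plumbing (the class and function names here are illustrative, not the real sglang objects):

```python
# Toy illustration of how the new flag flows from the server args to the model code.
global_server_args_dict = {"enable_mla": False}  # module-level default, as in the diff


class ToyServerArgs:
    def __init__(self, enable_mla: bool = False):
        self.enable_mla = enable_mla  # presumably exposed as an --enable-mla CLI flag


def init_model_runner(server_args: ToyServerArgs) -> None:
    # ModelRunner.__init__ copies the flag into the shared dict ...
    global_server_args_dict.update({"enable_mla": server_args.enable_mla})


def build_self_attn() -> str:
    # ... and DeepseekV2DecoderLayer reads it back when constructing self_attn.
    if global_server_args_dict["enable_mla"]:
        return "DeepseekV2AttentionMLA"
    return "DeepseekV2Attention"


init_model_runner(ToyServerArgs(enable_mla=True))
print(build_self_attn())  # DeepseekV2AttentionMLA
```
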