sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries.
Files changed (70)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +59 -11
  10. sglang/srt/disaggregation/mini_lb.py +45 -8
  11. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  12. sglang/srt/disaggregation/prefill.py +24 -9
  13. sglang/srt/entrypoints/http_server.py +8 -2
  14. sglang/srt/function_call_parser.py +77 -5
  15. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  16. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  17. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  18. sglang/srt/layers/attention/vision.py +2 -0
  19. sglang/srt/layers/layernorm.py +38 -16
  20. sglang/srt/layers/logits_processor.py +2 -2
  21. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  25. sglang/srt/layers/pooler.py +6 -0
  26. sglang/srt/layers/quantization/awq.py +5 -1
  27. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  28. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  29. sglang/srt/layers/radix_attention.py +13 -3
  30. sglang/srt/layers/rotary_embedding.py +170 -126
  31. sglang/srt/managers/data_parallel_controller.py +10 -3
  32. sglang/srt/managers/io_struct.py +7 -0
  33. sglang/srt/managers/mm_utils.py +85 -28
  34. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  35. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  36. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  37. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  38. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  40. sglang/srt/managers/schedule_batch.py +29 -12
  41. sglang/srt/managers/scheduler.py +31 -20
  42. sglang/srt/managers/tokenizer_manager.py +5 -1
  43. sglang/srt/mem_cache/memory_pool.py +87 -0
  44. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  45. sglang/srt/model_executor/forward_batch_info.py +51 -95
  46. sglang/srt/model_executor/model_runner.py +11 -24
  47. sglang/srt/models/deepseek.py +12 -2
  48. sglang/srt/models/deepseek_nextn.py +101 -6
  49. sglang/srt/models/deepseek_v2.py +144 -70
  50. sglang/srt/models/deepseek_vl2.py +9 -4
  51. sglang/srt/models/gemma3_causal.py +1 -1
  52. sglang/srt/models/llama4.py +0 -1
  53. sglang/srt/models/minicpmo.py +5 -1
  54. sglang/srt/models/mllama4.py +2 -2
  55. sglang/srt/models/qwen2_5_vl.py +3 -6
  56. sglang/srt/models/qwen2_vl.py +3 -7
  57. sglang/srt/models/roberta.py +178 -0
  58. sglang/srt/openai_api/adapter.py +18 -8
  59. sglang/srt/server_args.py +15 -22
  60. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  61. sglang/srt/torch_memory_saver_adapter.py +10 -1
  62. sglang/srt/utils.py +2 -1
  63. sglang/test/runners.py +6 -13
  64. sglang/test/test_utils.py +36 -18
  65. sglang/version.py +1 -1
  66. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
  67. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
  68. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  69. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  70. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -364,23 +364,23 @@ class ForwardBatch:

     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
         """
-        Merge all image inputs in the batch into a single MultiModalInputs object.
+        Merge all multimodal inputs in the batch into a single MultiModalInputs object.

         Returns:
-            if none, current batch contains no image input
+            if none, current batch contains no multimodal input

         """
         if not self.mm_inputs or all(x is None for x in self.mm_inputs):
             return None
-
         # Filter out None values
         valid_inputs = [x for x in self.mm_inputs if x is not None]

-        # Start with the first valid image input
-        merged = valid_inputs[0]
+        # TODO: is it expensive?
+        # a workaround to avoid importing `MultimodalInputs`
+        merged = valid_inputs[0].__class__(mm_items=[])

         # Merge remaining inputs
-        for mm_input in valid_inputs[1:]:
+        for mm_input in valid_inputs:
             merged.merge(mm_input)

         return merged
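
The rewritten merge above starts from a fresh, empty object and folds every valid input into it, instead of merging into (and thereby mutating) the first request's input. A minimal illustrative sketch of that pattern, using a hypothetical stand-in class rather than sglang's actual MultimodalInputs:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeMultimodalInputs:
    # Hypothetical stand-in for sglang's MultimodalInputs, for illustration only.
    mm_items: List[str] = field(default_factory=list)

    def merge(self, other: "FakeMultimodalInputs") -> None:
        # The real merge() accumulates per-request items; this mimics that behavior.
        self.mm_items.extend(other.mm_items)


def merge_mm_inputs(mm_inputs: List[Optional[FakeMultimodalInputs]]):
    if not mm_inputs or all(x is None for x in mm_inputs):
        return None
    valid_inputs = [x for x in mm_inputs if x is not None]
    # Build an empty object of the same class so no per-request input is mutated.
    merged = valid_inputs[0].__class__(mm_items=[])
    for mm_input in valid_inputs:
        merged.merge(mm_input)
    return merged


a = FakeMultimodalInputs(mm_items=["image_0"])
b = FakeMultimodalInputs(mm_items=["image_1", "image_2"])
print(merge_mm_inputs([a, None, b]).mm_items)  # ['image_0', 'image_1', 'image_2']
print(a.mm_items)                              # ['image_0'] -- left untouched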
@@ -407,104 +407,60 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.multimodal_inputs[i] is None
-                    else batch.multimodal_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
+        # batch_size * [3 * seq_len]
+        batch_size = self.seq_lens.shape[0]
+        mrope_positions_list = [[]] * batch_size
+        for batch_idx in range(batch_size):
+            mm_input = batch.multimodal_inputs[batch_idx]
+            if self.forward_mode.is_decode():
+                mrope_position_deltas = (
+                    [0]
+                    if mm_input is None
+                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
                 )
-        elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, mm_input in enumerate(batch.multimodal_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
+                next_input_positions = []
+                for mrope_position_delta in mrope_position_deltas:
+                    # batched deltas needs to be processed separately
+                    # Convert list of lists to tensor with shape [3, seq_len]
+                    next_input_positions += [
+                        MRotaryEmbedding.get_next_input_positions(
+                            mrope_position_delta,
+                            int(self.seq_lens[batch_idx]) - 1,
+                            int(self.seq_lens[batch_idx]),
+                        )
+                    ]
+                # 3 * N
+                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+            elif self.forward_mode.is_extend():
+                extend_seq_len, extend_prefix_len = (
+                    batch.extend_seq_lens[batch_idx],
+                    batch.extend_prefix_lens[batch_idx],
                 )
                 if mm_input is None:
                     # text only
-                    mrope_positions = [
+                    mrope_positions = torch.tensor(
                         [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
+                            [
+                                pos
+                                for pos in range(
+                                    extend_prefix_len,
+                                    extend_prefix_len + extend_seq_len,
+                                )
+                            ]
                         ]
-                    ] * 3
-                else:
-                    image_grid_thws_list = [
-                        item.image_grid_thws
-                        for item in mm_input.mm_items
-                        if item.image_grid_thws is not None
-                    ]
-                    image_grid_thw = (
-                        None
-                        if len(image_grid_thws_list) == 0
-                        else torch.cat(image_grid_thws_list, dim=0)
-                    )
-
-                    video_grid_thws_list = [
-                        item.video_grid_thws
-                        for item in mm_input.mm_items
-                        if item.video_grid_thws is not None
-                    ]
-                    video_grid_thw = (
-                        None
-                        if len(video_grid_thws_list) == 0
-                        else torch.cat(video_grid_thws_list, dim=0)
+                        * 3
                     )
-
-                    second_per_grid_ts_list = [
-                        item.second_per_grid_ts
-                        for item in mm_input.mm_items
-                        if item.second_per_grid_ts is not None
+                else:
+                    mrope_positions = mm_input.mrope_positions[
+                        :,
+                        extend_prefix_len : extend_prefix_len + extend_seq_len,
                     ]
-                    second_per_grid_ts = (
-                        None
-                        if len(second_per_grid_ts_list) == 0
-                        else torch.cat(second_per_grid_ts_list, dim=0)
-                    )
-
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
-                            image_grid_thw=image_grid_thw,
-                            video_grid_thw=video_grid_thw,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                            seq_len=len(self.input_ids),
-                            second_per_grid_ts=second_per_grid_ts,
-                            tokens_per_second=getattr(
-                                hf_config.vision_config, "tokens_per_second", None
-                            ),
-                        )
-                    )
-                    batch.multimodal_inputs[i].mrope_position_delta = (
-                        mrope_position_delta
-                    )
-                mrope_positions_list[i] = mrope_positions
+                mrope_positions_list[batch_idx] = mrope_positions

         self.mrope_positions = torch.cat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
+            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
+            dim=1,
+        ).to(dtype=torch.int64, device=model_runner.device)

     def get_max_chunk_capacity(self):
         # Maximum number of tokens in each chunk
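
In the new code path above, each request contributes a [3, seq_len] block of M-RoPE positions (one row each for the temporal, height, and width components), and the per-request blocks are concatenated along dim=1 into a single batch tensor. An illustrative sketch of that layout with made-up tensors (the real code builds each block from MRotaryEmbedding and the cached multimodal inputs):

import torch

# Each request contributes a [3, seq_len_i] block; concatenating along dim=1
# yields one [3, sum(seq_len_i)] tensor of int64 positions for the whole batch.
def concat_mrope_positions(per_request_positions):
    return torch.cat(per_request_positions, dim=1).to(torch.int64)

req0 = torch.arange(4).repeat(3, 1)        # [3, 4] block for a 4-token request
req1 = torch.arange(2).repeat(3, 1) + 10   # [3, 2] block for a 2-token request
print(concat_mrope_positions([req0, req1]).shape)  # torch.Size([3, 6])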
sglang/srt/model_executor/model_runner.py
@@ -91,11 +91,14 @@ from sglang.srt.utils import (
     set_cuda_arch,
 )

-logger = logging.getLogger(__name__)
-
+# Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+
+# Detect stragger ranks in model loading
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

+logger = logging.getLogger(__name__)
+

 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
@@ -177,7 +180,7 @@ class ModelRunner:
         if _ENABLE_JIT_DEEPGEMM:
             update_deep_gemm_config(gpu_id, server_args)

-        # If it is a draft model tp_group can be different.
+        # If it is a draft model, tp_group can be different
        self.initialize(min_per_gpu_memory)

     def initialize(self, min_per_gpu_memory: float):
@@ -230,7 +233,8 @@ class ModelRunner:

         if server_args.attention_backend is None:
             """
-            We auto select the fastest attention backend according to the current offering
+            Auto select the fastest attention backend.
+
             1. Models with MHA Architecture (e.g: Llama, QWen)
                 1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
                 1.2 In other cases, we will use flashinfer if available, otherwise use triton.
@@ -240,6 +244,7 @@
             """

             if not self.use_mla_backend:
+                # MHA architecture
                 if (
                     is_hopper_with_cuda_12_3()
                     and is_no_spec_infer_or_topk_one(server_args)
@@ -251,6 +256,7 @@ class ModelRunner:
                         "flashinfer" if is_flashinfer_available() else "triton"
                     )
             else:
+                # MLA architecture
                 if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 else:
@@ -259,7 +265,6 @@ class ModelRunner:
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
         elif self.use_mla_backend:
-            # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
                     "flashinfer",
@@ -275,7 +280,7 @@ class ModelRunner:
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
-                raise ValueError(f"MLA optimization not supported on CPU.")
+                raise ValueError("MLA optimization not supported on CPU.")

        if (
            server_args.attention_backend == "fa3"
@@ -310,18 +315,6 @@ class ModelRunner:
            )
            server_args.chunked_prefill_size = -1

-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ] or self.model_config.hf_config.architectures == [
-            "Qwen2_5_VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info("Automatically disable radix cache for qwen-vl series.")
-            server_args.disable_radix_cache = True
-
-        if server_args.enable_deepep_moe:
-            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
-
        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
@@ -964,12 +957,6 @@ class ModelRunner:
            return

        if self.server_args.disable_cuda_graph:
-            logger.warning(
-                "\n\nCUDA Graph is DISABLED.\n"
-                "This will cause significant performance degradation.\n"
-                "CUDA Graph should almost never be disabled in most usage scenarios.\n"
-                "If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.\n"
-            )
            return

        tic = time.time()
sglang/srt/models/deepseek.py
@@ -382,8 +382,14 @@ class DeepseekModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -416,14 +422,18 @@ class DeepseekForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
sglang/srt/models/deepseek_nextn.py
@@ -13,12 +13,14 @@
 # ==============================================================================

 """Inference-only DeepSeek NextN Speculative Decoding."""
+import logging
 from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import PretrainedConfig

+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -51,6 +53,9 @@ else:
     from vllm._custom_ops import awq_dequantize


+logger = logging.getLogger(__name__)
+
+
 class DeepseekModelNextN(nn.Module):
     def __init__(
         self,
@@ -134,7 +139,9 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")

         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -182,6 +189,48 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion > 0:
+            logger.info(
+                f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
+            )
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
+            names_to_remove = []
+            for suffix in suffix_list:
+                shared_expert_weight_name = (
+                    f"model.layers.0.mlp.shared_experts.{suffix}"
+                )
+                for num_repeat in range(self.n_share_experts_fusion):
+                    weights_list.append(
+                        (
+                            f"model.layers.0."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + num_repeat}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -190,8 +239,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
         )
+        cached_a_proj = {} if fuse_qkv_a_proj else None

         nextn_layer_prefix = "model.layers.0"
         nextn_spec_weight_names = [
@@ -264,11 +319,51 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                    if name.endswith(".bias") and name not in params_dict:
                        continue

-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    # Handle fused_qkv_a_proj
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+
+                            param_name = name.replace(
+                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                    else:
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)

        self_attn = self.model.decoder.self_attn
        if hasattr(self_attn.kv_b_proj, "qweight"):
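
The weight-loading change above caches the q_a_proj and kv_a_proj_with_mqa checkpoint tensors until both have been seen, then concatenates them along the output (row) dimension and loads the result into the single fused_qkv_a_proj_with_mqa parameter. A small illustrative sketch of the fusion step, with made-up shapes:

import torch

# Made-up DeepSeek-style sizes, for illustration only.
q_lora_rank = 1536
kv_lora_rank_plus_rope = 576
hidden_size = 7168

q_a_proj_weight = torch.randn(q_lora_rank, hidden_size)
kv_a_proj_weight = torch.randn(kv_lora_rank_plus_rope, hidden_size)

# Same operation as the diff: concatenate along the output (row) dimension.
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
print(fused_weight.shape)  # torch.Size([2112, 7168])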