sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -113,6 +113,8 @@ def synchronized():


 class HiCacheHF3FS(HiCacheStorage):
+    """HiCache backend that stores KV cache pages in HF3FS files."""
+
     default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"

     def __init__(
@@ -125,6 +127,7 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -134,9 +137,13 @@ class HiCacheHF3FS(HiCacheStorage):
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0

         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -171,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage):
         dtype: torch.dtype,
         storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
+        """Create a HiCacheHF3FS instance from environment configuration.
+
+        Environment:
+            - Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
+            - Falls back to a local single-machine config when the env var is not set.
+
+        Raises:
+            ValueError: If MLA Model is requested without global metadata server or required keys are missing.
+        """
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )

-        rank = storage_config.tp_rank if storage_config is not None else 0
+        if storage_config is not None:
+            rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model
+        else:
+            rank, is_mla_model = 0, False
+
+        mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"

         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             return HiCacheHF3FS(
                 rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
@@ -209,26 +233,34 @@ class HiCacheHF3FS(HiCacheStorage):
             raise ValueError(f"Missing required keys in config: {missing_keys}")

         # Choose metadata client based on configuration
-        if "metadata_server_url" in config and config["metadata_server_url"]:
+        if config.get("metadata_server_url"):
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
         else:
+            # Enable MLA optimization only when using the global metadata client
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             # Use local metadata client for single-machine deployment
             metadata_client = Hf3fsLocalMetadataClient()

+        rank_for_path = 0 if is_mla_model else rank
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
         )

     def get(
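
For reference, a hedged example of the JSON config that `from_env_config` reads. The key names (`file_path_prefix`, `file_size`, `numjobs`, `entries`, optional `metadata_server_url`) come from the hunks above; the values and paths below are illustrative placeholders, not recommendations.

import json
import os

config = {
    "file_path_prefix": "/data/hicache",  # per-rank files become "<prefix>.<rank>.bin"
    "file_size": 1 << 30,                 # placeholder size in bytes
    "numjobs": 4,                         # placeholder I/O parallelism
    "entries": 8,                         # placeholder batching factor
    # Optional: with a metadata server the global client is selected, which is
    # also what allows MLA models to share a single backup file.
    # "metadata_server_url": "http://meta-host:18000",
}

with open("/tmp/hf3fs_config.json", "w") as f:
    json.dump(config, f)
os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = "/tmp/hf3fs_config.json"
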
@@ -312,6 +344,10 @@ class HiCacheHF3FS(HiCacheStorage):
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -363,18 +399,29 @@ class HiCacheHF3FS(HiCacheStorage):

         return all(results)

-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])

-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False

-    @synchronized()
-    def clear(self) -> None:
-        self.metadata_client.clear(self.rank)
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False

     def close(self) -> None:
         try:
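
The MLA-specific behavior spread across the hunks above reduces to one rule: every tensor-parallel rank of an MLA model holds identical KV data, so only rank 0 writes a backup file and every rank resolves to rank 0's file. A minimal standalone sketch of that rule follows; the helper name is hypothetical and this is not the sglang API.

def resolve_hf3fs_rank(rank: int, is_mla_model: bool) -> tuple:
    """Return (effective_rank, skip_backup) for a HiCacheHF3FS client."""
    if is_mla_model and rank != 0:
        # Read rank 0's file; never write a duplicate backup.
        return 0, True
    return rank, False

assert resolve_hf3fs_rank(3, is_mla_model=True) == (0, True)
assert resolve_hf3fs_rank(3, is_mla_model=False) == (3, False)
assert resolve_hf3fs_rank(0, is_mla_model=True) == (0, False)
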
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -159,6 +159,7 @@ class MooncakeStore(HiCacheStorage):
     def batch_set(
         self,
         keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
@@ -253,7 +254,7 @@ class MooncakeStore(HiCacheStorage):
         pass

     def clear(self) -> None:
-        raise (NotImplementedError)
+        self.store.remove_all()

     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
sglang/srt/mem_cache/swa_radix_cache.py
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)

-    def cache_unfinished_req(self, req: Req) -> None:
+    def cache_unfinished_req(self, req: Req, chunked=False) -> None:
         """Cache request when it is unfinished."""
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
sglang/srt/model_executor/model_runner.py
@@ -307,7 +307,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
@@ -1440,14 +1443,12 @@ class ModelRunner:
                 else self.server_args.attention_backend
             )
             if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-                assert (
-                    self.server_args.speculative_algorithm is None
-                ), "Currently HybridAttentionBackend does not support speculative decoding."
                 from sglang.srt.layers.attention.hybrid_attn_backend import (
                     HybridAttnBackend,
                 )

                 attn_backend = HybridAttnBackend(
+                    self,
                     decode_backend=self._get_attention_backend_from_str(
                         self.decode_attention_backend_str
                     ),
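
The hunk above passes the model runner into HybridAttnBackend alongside the separate decode and prefill backends. A standalone sketch of the dispatch idea follows, under the assumption that the hybrid backend simply routes each batch to one of the two wrapped backends by forward mode; the types and names are dummies, not sglang's.

from dataclasses import dataclass
from typing import Callable

@dataclass
class HybridDispatchSketch:
    prefill_backend: Callable[[str], str]
    decode_backend: Callable[[str], str]

    def run(self, forward_mode: str, batch: str) -> str:
        # Decode batches go to the decode backend; everything else
        # (extend/prefill) goes to the prefill backend.
        backend = self.decode_backend if forward_mode == "decode" else self.prefill_backend
        return backend(batch)

hybrid = HybridDispatchSketch(lambda b: f"prefill({b})", lambda b: f"decode({b})")
assert hybrid.run("extend", "b0") == "prefill(b0)"
assert hybrid.run("decode", "b1") == "decode(b1)"
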
sglang/srt/model_loader/loader.py
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
             # random values to the weights.
             initialize_dummy_weights(model)

-            # Model weight loading consists of two stages:
-            # 1. Initial weight loading.
-            # 2. Post-processing of weights, including assigning specific member variables.
-            # For `dummy_init`, only the second stage is required.
-            if hasattr(model, "post_load_weights"):
-                if (
-                    model_config.hf_config.architectures[0]
-                    == "DeepseekV3ForCausalLMNextN"
-                ):
-                    model.post_load_weights(is_nextn=True)
-                else:
-                    model.post_load_weights()
+            post_load_weights(model, model_config)

         return model.eval()

@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                         state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+            post_load_weights(model, model_config)
+
         return model.eval()

     @staticmethod
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                     f_key = f"{model_name}/files/{file_name}"
                     client.setstr(f_key, file_content)

-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

+        post_load_weights(model, model_config)
+
     def _load_model_from_remote_fs(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)

-        with create_remote_connector(model_weights, device_config.device) as client:
+        with create_remote_connector(
+            model_weights, device=device_config.device
+        ) as client:
             connector_type = get_connector_type(client)
             if connector_type == ConnectorType.KV:
-                self._load_model_from_remote_kv(model, client)
+                self._load_model_from_remote_kv(model, model_config, client)
             elif connector_type == ConnectorType.FS:
                 self._load_model_from_remote_fs(
                     model, client, model_config, device_config
sglang/srt/model_loader/utils.py
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],

 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
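
The loader.py call sites above now all share this helper, so any loader that produces weights by a non-standard path still runs the model's post-processing stage. A standalone illustration of the contract with a dummy model follows; the class and function names are hypothetical, not sglang classes.

class DummyDeepseekNextN:
    """Stand-in for a model exposing the optional post_load_weights hook."""

    def post_load_weights(self, is_nextn: bool = False) -> None:
        self.finished = "nextn" if is_nextn else "base"

def run_post_load(model, architecture: str) -> None:
    # Mirrors the dispatch in the post_load_weights helper above.
    if hasattr(model, "post_load_weights"):
        if architecture == "DeepseekV3ForCausalLMNextN":
            model.post_load_weights(is_nextn=True)
        else:
            model.post_load_weights()

m = DummyDeepseekNextN()
run_post_load(m, "DeepseekV3ForCausalLMNextN")
assert m.finished == "nextn"
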
sglang/srt/models/deepseek_v2.py
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -122,6 +123,7 @@

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
             k[..., : self.qk_nope_head_dim] = k_nope
             k[..., self.qk_nope_head_dim :] = k_pe

-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+            if not _is_npu:
+                latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+                latent_cache[:, :, self.kv_lora_rank :] = k_pe

-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
+                # Save latent cache
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+                )
+            else:
+                # To reduce a time-costing split operation
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+                )

             return q, k, v, forward_batch

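
The NPU branch above skips materializing `latent_cache` and hands `kv_a` and `k_pe` to `set_kv_buffer` separately, avoiding a concatenation that would later be split again. A small torch sketch of the equivalence being exploited, with illustrative shapes rather than DeepSeek's real dimensions:

import torch

tokens, kv_lora_rank, rope_dim = 4, 8, 2
kv_a = torch.randn(tokens, 1, kv_lora_rank)
k_pe = torch.randn(tokens, 1, rope_dim)

# The non-NPU path packs both pieces into one latent cache tensor...
latent_cache = torch.empty(tokens, 1, kv_lora_rank + rope_dim)
latent_cache[:, :, :kv_lora_rank] = kv_a
latent_cache[:, :, kv_lora_rank:] = k_pe

# ...but the two slices are exactly kv_a and k_pe, so storing them
# side by side without the copy is equivalent.
assert torch.equal(latent_cache[:, :, :kv_lora_rank], kv_a)
assert torch.equal(latent_cache[:, :, kv_lora_rank:], k_pe)
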
@@ -2177,6 +2185,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2406,18 +2416,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )

         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]

-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
@@ -2480,6 +2498,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
             expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                 num_experts=self.config.n_routed_experts
sglang/srt/models/gpt_oss.py
@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans


-def _enable_fused_set_kv_buffer():
-    return _is_cuda
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


 # TODO maybe move to a model-common utils
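
The gate above now also checks the KV-cache dtype, since the fused write path assumes a bfloat16 pool; otherwise the attention op is asked to save the cache itself. A standalone sketch of the check with a dummy pool object (not sglang's token_to_kv_pool, and a hypothetical function name):

from dataclasses import dataclass
import torch

@dataclass
class DummyKVPool:
    dtype: torch.dtype

def enable_fused_set_kv_buffer(is_cuda: bool, pool: DummyKVPool) -> bool:
    # Fused path only on CUDA with a bfloat16 KV cache; callers pass
    # save_kv_cache=not enable_fused_set_kv_buffer(...) to the attention op.
    return is_cuda and pool.dtype == torch.bfloat16

assert enable_fused_set_kv_buffer(True, DummyKVPool(torch.bfloat16)) is True
assert enable_fused_set_kv_buffer(True, DummyKVPool(torch.float16)) is False
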
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer()
+                if _enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(),
+            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
@@ -1029,10 +1030,6 @@ class GptOssForCausalLM(nn.Module):
         )

         params_dict = dict(self.named_parameters())
-        params_checker = {k: False for k, v in params_dict.items()}
-
-        for other_loaded_param_name in other_loaded_param_names:
-            params_checker[other_loaded_param_name] = True

         for name, loaded_weight in weights:
             loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
@@ -1069,7 +1066,6 @@ class GptOssForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
-                params_checker[name] = True
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -1092,7 +1088,6 @@ class GptOssForCausalLM(nn.Module):
                         name,
                         shard_id=shard_id,
                     )
-                    params_checker[name] = True
                     break
                 else:
                     if name.endswith(".bias") and name not in params_dict:
@@ -1111,17 +1106,9 @@ class GptOssForCausalLM(nn.Module):
                             param, "weight_loader", default_weight_loader
                         )
                         weight_loader(param, loaded_weight)
-                        params_checker[name] = True
                     else:
                         logger.warning(f"Parameter {name} not found in params_dict")

-        not_loaded_params = [k for k, v in params_checker.items() if not v]
-        if tp_rank == 0:
-            if len(not_loaded_params) > 0:
-                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
-            else:
-                logging.info("All parameters loaded successfully.")
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight

sglang/srt/models/llama_eagle3.py
@@ -185,9 +185,13 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         )
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        self.load_lm_head_from_target = False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
+            if config.draft_vocab_size is None:
+                self.load_lm_head_from_target = True
+                config.draft_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 config.draft_vocab_size,
                 config.hidden_size,
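
The new flag above lets EAGLE3 draft checkpoints omit `draft_vocab_size`: the draft falls back to the target vocab size and records that its lm_head should be taken from the target model. A standalone sketch of that fallback, using a plain dataclass and a hypothetical helper rather than the sglang config classes:

from dataclasses import dataclass
from typing import Optional

@dataclass
class DraftConfig:
    vocab_size: int
    draft_vocab_size: Optional[int] = None
    tie_word_embeddings: bool = False

def resolve_draft_lm_head(config: DraftConfig) -> bool:
    """Return True when the draft model should copy lm_head from the target."""
    load_lm_head_from_target = False
    if not config.tie_word_embeddings and config.draft_vocab_size is None:
        load_lm_head_from_target = True
        config.draft_vocab_size = config.vocab_size
    return load_lm_head_from_target

cfg = DraftConfig(vocab_size=128256)
assert resolve_draft_lm_head(cfg) and cfg.draft_vocab_size == 128256
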