sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +12 -0
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +24 -14
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  29. sglang/srt/layers/moe/ep_moe/layer.py +12 -6
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/topk.py +35 -12
  32. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  33. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  34. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  35. sglang/srt/layers/quantization/mxfp4.py +9 -4
  36. sglang/srt/layers/quantization/utils.py +13 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  38. sglang/srt/layers/rotary_embedding.py +28 -1
  39. sglang/srt/layers/sampler.py +29 -5
  40. sglang/srt/managers/cache_controller.py +62 -96
  41. sglang/srt/managers/detokenizer_manager.py +43 -2
  42. sglang/srt/managers/io_struct.py +27 -0
  43. sglang/srt/managers/mm_utils.py +5 -1
  44. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  45. sglang/srt/managers/scheduler.py +36 -2
  46. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  47. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  48. sglang/srt/managers/tokenizer_manager.py +86 -39
  49. sglang/srt/mem_cache/chunk_cache.py +1 -1
  50. sglang/srt/mem_cache/hicache_storage.py +20 -3
  51. sglang/srt/mem_cache/hiradix_cache.py +75 -68
  52. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  53. sglang/srt/mem_cache/memory_pool.py +4 -0
  54. sglang/srt/mem_cache/memory_pool_host.py +2 -4
  55. sglang/srt/mem_cache/radix_cache.py +5 -4
  56. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  57. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
  58. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  59. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  60. sglang/srt/model_executor/model_runner.py +5 -4
  61. sglang/srt/model_loader/loader.py +15 -24
  62. sglang/srt/model_loader/utils.py +12 -0
  63. sglang/srt/models/deepseek_v2.py +26 -10
  64. sglang/srt/models/gpt_oss.py +0 -14
  65. sglang/srt/models/llama_eagle3.py +4 -0
  66. sglang/srt/models/longcat_flash.py +1015 -0
  67. sglang/srt/models/longcat_flash_nextn.py +691 -0
  68. sglang/srt/models/qwen2.py +26 -3
  69. sglang/srt/models/qwen2_5_vl.py +65 -41
  70. sglang/srt/models/qwen2_moe.py +22 -2
  71. sglang/srt/models/transformers.py +1 -1
  72. sglang/srt/multimodal/processors/base_processor.py +4 -2
  73. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  74. sglang/srt/server_args.py +112 -55
  75. sglang/srt/speculative/eagle_worker.py +28 -8
  76. sglang/srt/utils.py +14 -0
  77. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  78. sglang/version.py +1 -1
  79. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
  80. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
  81. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  82. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py
@@ -7,7 +7,6 @@ from functools import wraps
 import psutil
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -464,8 +463,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
-        local_rank = get_tensor_model_parallel_rank()
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -704,7 +702,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
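
Note: get_buffer_meta on both host pools now takes local_rank from the caller instead of importing get_tensor_model_parallel_rank itself, keeping the host pools free of distributed-state imports. A minimal sketch of the new call shape (host_kv_pool, keys, and indices are illustrative names, not from the diff):

    # Hedged sketch: the owner of the pool resolves the TP rank once and
    # passes it in; get_buffer_meta no longer looks it up internally.
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    local_rank = get_tensor_model_parallel_rank()
    meta = host_kv_pool.get_buffer_meta(keys, indices, local_rank)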
sglang/srt/mem_cache/radix_cache.py
@@ -62,7 +62,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -195,7 +194,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )
 
-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
        if self.disable:
            return 0
 
@@ -240,7 +239,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -261,7 +260,9 @@ class RadixCache(BasePrefixCache):
         page_aligned_token_ids = token_ids[:page_aligned_len]
 
         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
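
Note: the new chunked flag is threaded from cache_unfinished_req into insert, replacing the per-node backuped_storage bookkeeping removed above. A hedged sketch of the call path (the scheduler call site is an assumption, not shown in this diff):

    # Hypothetical caller: a scheduler caching a partially prefilled request
    # forwards chunked=True so the tree can distinguish chunked-prefill inserts.
    tree_cache.cache_unfinished_req(req, chunked=True)
    # ...which internally forwards the flag:
    # self.insert(page_aligned_token_ids, page_aligned_kv_indices, chunked=True)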
sglang/srt/mem_cache/radix_cache_cpp.py
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -125,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -134,9 +135,13 @@ class HiCacheHF3FS(HiCacheStorage):
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -209,10 +214,14 @@ class HiCacheHF3FS(HiCacheStorage):
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
+        is_mla_model = False
         if "metadata_server_url" in config and config["metadata_server_url"]:
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
+            # Enable MLA optimization only when using the global metadata client
+            is_mla_model = storage_config.is_mla_model if storage_config else False
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
@@ -222,13 +231,15 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
         )
 
     def get(
@@ -312,6 +323,10 @@ class HiCacheHF3FS(HiCacheStorage):
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -363,18 +378,29 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-    @synchronized()
-    def clear(self) -> None:
-        self.metadata_client.clear(self.rank)
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False
 
     def close(self) -> None:
         try:
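
Note: two behaviors here are worth spelling out. For MLA models every TP rank holds an identical KV cache, so non-zero ranks alias rank 0's file and short-circuit set via skip_backup; and batch_exists returns the length of the contiguous prefix of keys already present (the index of the first miss). A small self-contained sketch of that semantics (the in-memory store is hypothetical):

    # batch_exists semantics: count of leading keys that already exist.
    def batch_exists_semantics(stored: set, keys: list) -> int:
        for i, key in enumerate(keys):
            if key not in stored:
                return i
        return len(keys)

    assert batch_exists_semantics({"a", "b"}, ["a", "b", "c"]) == 2  # first miss at index 2
    assert batch_exists_semantics({"a", "b"}, ["c", "a"]) == 0       # miss immediately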
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -159,6 +159,7 @@ class MooncakeStore(HiCacheStorage):
     def batch_set(
         self,
         keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
@@ -253,7 +254,7 @@ class MooncakeStore(HiCacheStorage):
         pass
 
     def clear(self) -> None:
-        raise (NotImplementedError)
+        self.store.remove_all()
 
     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
sglang/srt/mem_cache/swa_radix_cache.py
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
 
-    def cache_unfinished_req(self, req: Req) -> None:
+    def cache_unfinished_req(self, req: Req, chunked=False) -> None:
         """Cache request when it is unfinished."""
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
sglang/srt/model_executor/model_runner.py
@@ -307,7 +307,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
@@ -1440,14 +1443,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
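
Note: removing the assert means mixed prefill/decode attention backends are no longer blocked when speculative decoding is enabled, and HybridAttnBackend now receives the ModelRunner itself. A hedged construction sketch; everything past decode_backend is cut off in the hunk, so the trailing arguments are assumed unchanged:

    # Sketch: the runner passes itself as the new first argument so the hybrid
    # backend can consult runner state.
    attn_backend = HybridAttnBackend(
        model_runner,  # new first argument
        decode_backend=model_runner._get_attention_backend_from_str(
            model_runner.decode_attention_backend_str
        ),
        # remaining keyword arguments unchanged (elided in this hunk)
    )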
sglang/srt/model_loader/loader.py
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
             # random values to the weights.
             initialize_dummy_weights(model)
 
-            # Model weight loading consists of two stages:
-            # 1. Initial weight loading.
-            # 2. Post-processing of weights, including assigning specific member variables.
-            # For `dummy_init`, only the second stage is required.
-            if hasattr(model, "post_load_weights"):
-                if (
-                    model_config.hf_config.architectures[0]
-                    == "DeepseekV3ForCausalLMNextN"
-                ):
-                    model.post_load_weights(is_nextn=True)
-                else:
-                    model.post_load_weights()
+            post_load_weights(model, model_config)
 
         return model.eval()
 
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                 state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+            post_load_weights(model, model_config)
+
             return model.eval()
 
     @staticmethod
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                         f_key = f"{model_name}/files/{file_name}"
                         client.setstr(f_key, file_content)
 
-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
 
+        post_load_weights(model, model_config)
+
     def _load_model_from_remote_fs(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
 
-        with create_remote_connector(model_weights, device_config.device) as client:
+        with create_remote_connector(
+            model_weights, device=device_config.device
+        ) as client:
             connector_type = get_connector_type(client)
             if connector_type == ConnectorType.KV:
-                self._load_model_from_remote_kv(model, client)
+                self._load_model_from_remote_kv(model, model_config, client)
             elif connector_type == ConnectorType.FS:
                 self._load_model_from_remote_fs(
                     model, client, model_config, device_config
sglang/srt/model_loader/utils.py
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
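
Note: post_load_weights consolidates the second-stage logic that DummyModelLoader, ShardedStateLoader, and RemoteModelLoader previously duplicated or skipped. A model opts in simply by defining the hook; a minimal, hypothetical model-side sketch:

    # Hypothetical model class: defining post_load_weights is enough for the
    # loaders above to run second-stage weight post-processing after loading.
    class MyModel(nn.Module):
        def post_load_weights(self, is_nextn: bool = False):
            # e.g. derive absorbed/fused matrices from the freshly loaded weights
            ...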
sglang/srt/models/deepseek_v2.py
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
            k[..., : self.qk_nope_head_dim] = k_nope
            k[..., self.qk_nope_head_dim :] = k_pe
 
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+            if not _is_npu:
+                latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+                latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
+                # Save latent cache
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+                )
+            else:
+                # To reduce a time-costing split operation
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+                )
 
         return q, k, v, forward_batch
 
@@ -2406,18 +2414,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
        for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]
 
-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
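
Note: two independent changes in this file. On NPU, the MLA latent cache is written via set_kv_buffer as its two components (kv_a and k_pe) instead of being packed into one latent_cache tensor first, avoiding a costly split on the hot path. Separately, the UE8M0 requantization now selects projection modules based on q_lora_rank: configs without query LoRA (q_lora_rank set to null, as in DeepSeek-V2-Lite) expose plain q_proj / kv_a_proj_with_mqa rather than the fused low-rank pair. Condensed from the hunk above (attn stands for layer.self_attn):

    # Which MLA projections get UE8M0-requantized, by config shape:
    modules = [attn.kv_b_proj, attn.o_proj]
    if config.q_lora_rank is not None:
        modules += [attn.fused_qkv_a_proj_with_mqa, attn.q_b_proj]  # low-rank query path
    else:
        modules += [attn.kv_a_proj_with_mqa, attn.q_proj]  # plain query projection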
sglang/srt/models/gpt_oss.py
@@ -1029,10 +1029,6 @@ class GptOssForCausalLM(nn.Module):
         )
 
         params_dict = dict(self.named_parameters())
-        params_checker = {k: False for k, v in params_dict.items()}
-
-        for other_loaded_param_name in other_loaded_param_names:
-            params_checker[other_loaded_param_name] = True
 
         for name, loaded_weight in weights:
             loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
@@ -1069,7 +1065,6 @@ class GptOssForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
-                params_checker[name] = True
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -1092,7 +1087,6 @@ class GptOssForCausalLM(nn.Module):
                         name,
                         shard_id=shard_id,
                     )
-                    params_checker[name] = True
                     break
                 else:
                     if name.endswith(".bias") and name not in params_dict:
@@ -1111,17 +1105,9 @@ class GptOssForCausalLM(nn.Module):
                         param, "weight_loader", default_weight_loader
                     )
                     weight_loader(param, loaded_weight)
-                    params_checker[name] = True
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")
 
-        not_loaded_params = [k for k, v in params_checker.items() if not v]
-        if tp_rank == 0:
-            if len(not_loaded_params) > 0:
-                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
-            else:
-                logging.info("All parameters loaded successfully.")
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 
sglang/srt/models/llama_eagle3.py
@@ -185,9 +185,13 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         )
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        self.load_lm_head_from_target = False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
+            if config.draft_vocab_size is None:
+                self.load_lm_head_from_target = True
+                config.draft_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 config.draft_vocab_size,
                 config.hidden_size,
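
Note: the new fallback lets an EAGLE3 draft config omit draft_vocab_size; the draft then inherits the target model's vocab size, and load_lm_head_from_target marks the lm_head to be copied from the target model rather than loaded from the draft checkpoint. Condensed semantics for the untied-embeddings branch:

    # Fallback when the draft config does not specify a draft vocab size:
    if config.draft_vocab_size is None:
        self.load_lm_head_from_target = True        # take lm_head from the target model
        config.draft_vocab_size = config.vocab_size # reuse the target vocab size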