sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (41)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/bench_serving.py +11 -3
  3. sglang/lang/backend/openai.py +10 -0
  4. sglang/srt/configs/model_config.py +11 -2
  5. sglang/srt/constrained/xgrammar_backend.py +6 -0
  6. sglang/srt/layers/attention/__init__.py +0 -1
  7. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  9. sglang/srt/layers/logits_processor.py +30 -2
  10. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
  11. sglang/srt/layers/moe/topk.py +14 -0
  12. sglang/srt/layers/quantization/fp8.py +42 -2
  13. sglang/srt/layers/quantization/fp8_kernel.py +91 -18
  14. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  15. sglang/srt/managers/io_struct.py +29 -8
  16. sglang/srt/managers/schedule_batch.py +22 -15
  17. sglang/srt/managers/schedule_policy.py +1 -1
  18. sglang/srt/managers/scheduler.py +71 -34
  19. sglang/srt/managers/session_controller.py +102 -27
  20. sglang/srt/managers/tokenizer_manager.py +95 -55
  21. sglang/srt/managers/tp_worker.py +7 -0
  22. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  23. sglang/srt/model_executor/forward_batch_info.py +42 -3
  24. sglang/srt/model_executor/model_runner.py +4 -6
  25. sglang/srt/model_loader/loader.py +22 -11
  26. sglang/srt/models/gemma2.py +19 -0
  27. sglang/srt/models/llama.py +13 -2
  28. sglang/srt/models/llama_eagle.py +132 -0
  29. sglang/srt/openai_api/adapter.py +79 -2
  30. sglang/srt/openai_api/protocol.py +50 -0
  31. sglang/srt/sampling/sampling_params.py +9 -2
  32. sglang/srt/server.py +45 -39
  33. sglang/srt/server_args.py +17 -30
  34. sglang/srt/speculative/spec_info.py +19 -0
  35. sglang/srt/utils.py +62 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
  38. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
  39. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import fastapi
 import uvloop
@@ -53,12 +53,15 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -173,6 +176,18 @@ class TokenizerManager:

         # Others
         self.gracefully_exit = False
+        self.init_weights_update_group_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.update_weights_from_distributed_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.update_weights_from_tensor_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.get_weights_by_name_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         # Metrics
         if self.enable_metrics:
@@ -190,8 +205,7 @@ class TokenizerManager:
     ):
         created_time = time.time()

-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -251,8 +265,9 @@ class TokenizerManager:
         return_logprob = obj.return_logprob
         logprob_start_len = obj.logprob_start_len
         top_logprobs_num = obj.top_logprobs_num
-        session_id = obj.session[0] if obj.session else None
-        session_rid = obj.session[1] if obj.session else None
+        session_params = (
+            SessionParams(**obj.session_params) if obj.session_params else None
+        )

         if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
@@ -279,8 +294,7 @@ class TokenizerManager:
                 obj.stream,
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
-                session_id=session_id,
-                session_rid=session_rid,
+                session_params=session_params,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -440,8 +454,7 @@ class TokenizerManager:
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
@@ -456,7 +469,7 @@ class TokenizerManager:

     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
-    ) -> Tuple[bool, str, int]:
+    ) -> Tuple[bool, str]:
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -485,15 +498,11 @@ class TokenizerManager:
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
-        self.send_to_scheduler.send_pyobj(obj)
-
-        self.init_weights_update_group_result = asyncio.Future()
+        self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
         ), "dp_size must be 1 for init parameter update group"
-        result = await self.init_weights_update_group_result
+        result = (await self.init_weights_update_group_communicator(obj))[0]
         return result.success, result.message

     async def update_weights_from_distributed(
@@ -501,51 +510,59 @@ class TokenizerManager:
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for update weights from distributed"
+
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_weights_from_distributed_communicator(obj))[0]
+        return result.success, result.message
+
+    async def update_weights_from_tensor(
+        self,
+        obj: UpdateWeightsFromTensorReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for update weights from tensor"

         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            self.send_to_scheduler.send_pyobj(obj)
-            self.parameter_update_result: Awaitable[
-                UpdateWeightsFromDistributedReqOutput
-            ] = asyncio.Future()
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for update weights from distributed"
-            result = await self.parameter_update_result
+            result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
-
-        self.send_to_scheduler.send_pyobj(obj)
-        self.get_weights_by_name_result = asyncio.Future()
+        self.auto_create_handle_loop()
+        results = await self.get_weights_by_name_communicator(obj)
+        all_parameters = [r.parameter for r in results]
         if self.server_args.dp_size == 1:
-            result = await self.get_weights_by_name_result
-            return result.parameter
+            return all_parameters[0]
         else:
-            self.get_weights_by_name_tmp = []
-            result = await self.get_weights_by_name_result
-            all_parameters = [r.parameter for r in result]
             return all_parameters

     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()
+
+        if obj.session_id is None:
+            obj.session_id = uuid.uuid4().hex
+        elif obj.session_id in self.session_futures:
+            return None

-        session_id = uuid.uuid4().hex
-        obj.session_id = session_id
         self.send_to_scheduler.send_pyobj(obj)
-        self.session_futures[session_id] = asyncio.Future()
-        session_id = await self.session_futures[session_id]
-        del self.session_futures[session_id]
+
+        self.session_futures[obj.session_id] = asyncio.Future()
+        session_id = await self.session_futures[obj.session_id]
+        del self.session_futures[obj.session_id]
         return session_id

     async def close_session(
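Note: generate requests now attach session state as a plain dict that the TokenizerManager rehydrates with SessionParams(**obj.session_params), replacing the old (session_id, rid) tuple, and open_session honors a caller-supplied session_id (returning None when the id already exists or the open fails). A hedged client-side sketch of the new flow; the HTTP routes and the SessionParams field names (id, rid) are assumptions, not taken from this diff:

    import requests

    base_url = "http://localhost:30000"

    # Open a session; the server generates an id when none is supplied.
    # Request body elided; see OpenSessionReqInput in io_struct.py for fields.
    session_id = requests.post(f"{base_url}/open_session", json={}).json()

    # Later requests attach to the session via the new session_params dict,
    # which the TokenizerManager turns into SessionParams(**...).
    requests.post(
        f"{base_url}/generate",
        json={
            "text": "Hello",
            "session_params": {"id": session_id, "rid": None},
        },
    )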
@@ -568,7 +585,7 @@ class TokenizerManager:
         background_tasks.add_task(abort_request)
         return background_tasks

-    def create_handle_loop(self):
+    def auto_create_handle_loop(self):
         if not self.to_create_loop:
             return

@@ -697,7 +714,7 @@ class TokenizerManager:
             )
         elif isinstance(recv_obj, OpenSessionReqOutput):
             self.session_futures[recv_obj.session_id].set_result(
-                recv_obj.session_id
+                recv_obj.session_id if recv_obj.success else None
             )
         elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
             if self.server_args.dp_size == 1:
@@ -711,21 +728,19 @@ class TokenizerManager:
             assert (
                 self.server_args.dp_size == 1
             ), "dp_size must be 1 for init parameter update group"
-            self.init_weights_update_group_result.set_result(recv_obj)
+            self.init_weights_update_group_communicator.handle_recv(recv_obj)
         elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
             assert (
                 self.server_args.dp_size == 1
             ), "dp_size must be 1 for update weights from distributed"
-            self.parameter_update_result.set_result(recv_obj)
+            self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
+        elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for update weights from tensor"
+            self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
         elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-            if self.server_args.dp_size == 1:
-                self.get_weights_by_name_result.set_result(recv_obj)
-            else:
-                self.get_weights_by_name_tmp.append(recv_obj)
-                if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
-                    self.get_weights_by_name_result.set_result(
-                        self.get_weights_by_name_tmp
-                    )
+            self.get_weights_by_name_communicator.handle_recv(recv_obj)
         else:
             raise ValueError(f"Invalid object: {recv_obj=}")

@@ -809,3 +824,28 @@ class SignalHandler:
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True
+
+
+T = TypeVar("T")
+
+
+class _Communicator(Generic[T]):
+    def __init__(self, sender, fan_out: int):
+        self._sender = sender
+        self._fan_out = fan_out
+        self._result_future: Optional[asyncio.Future] = None
+        self._result_values: Optional[List[T]] = None
+
+    async def __call__(self, obj):
+        self._sender.send_pyobj(obj)
+        self._result_future = asyncio.Future()
+        self._result_values = []
+        await self._result_future
+        result_values = self._result_values
+        self._result_future = self._result_values = None
+        return result_values
+
+    def handle_recv(self, recv_obj: T):
+        self._result_values.append(recv_obj)
+        if len(self._result_values) == self._fan_out:
+            self._result_future.set_result(None)
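Note: _Communicator replaces the ad-hoc per-request futures removed above. Each caller sends one request and suspends until fan_out replies (one per data-parallel rank) have been delivered to handle_recv. A self-contained sketch of the fan-in behavior, with a stand-in sender object invented for the demo:

    import asyncio

    class FakeSender:
        def send_pyobj(self, obj):  # stands in for the ZMQ socket
            print(f"sent: {obj!r}")

    async def demo():
        comm = _Communicator(FakeSender(), fan_out=2)  # e.g. dp_size == 2

        async def deliver_replies():
            await asyncio.sleep(0.1)
            comm.handle_recv("reply from dp rank 0")
            comm.handle_recv("reply from dp rank 1")  # fan_out reached here

        asyncio.ensure_future(deliver_replies())
        print(await comm("get_weights_by_name request"))
        # -> ['reply from dp rank 0', 'reply from dp rank 1']

    asyncio.run(demo())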
sglang/srt/managers/tp_worker.py

@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -188,6 +189,12 @@ class TpModelWorker:
         )
         return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.model_runner.update_weights_from_tensor(
+            recv_req.name, recv_req.tensor
+        )
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.model_runner.get_weights_by_name(
             recv_req.name, recv_req.truncate_size
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
         success, message = self.worker.update_weights_from_distributed(recv_req)
         return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.worker.update_weights_from_tensor(recv_req)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)

sglang/srt/model_executor/forward_batch_info.py

@@ -45,6 +45,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm


 class ForwardMode(IntEnum):
@@ -59,6 +60,11 @@ class ForwardMode(IntEnum):
     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
     IDLE = auto()

+    # Used in speculative decoding: verify a batch in the target model.
+    TARGET_VERIFY = auto()
+    # Used in speculative decoding: extend a batch in the draft model.
+    DRAFT_EXTEND = auto()
+
     # A dummy first batch to start the pipeline for the overlap scheduler.
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()
@@ -67,7 +73,12 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.PREFILL

     def is_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.MIXED
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == ForwardMode.TARGET_VERIFY
+        )

     def is_decode(self):
         return self == ForwardMode.DECODE
@@ -78,6 +89,15 @@ class ForwardMode(IntEnum):
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_target_verify(self):
+        return self == ForwardMode.TARGET_VERIFY
+
+    def is_draft_extend(self):
+        return self == ForwardMode.DRAFT_EXTEND
+
+    def is_cuda_graph(self):
+        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

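Note: the two new speculative modes piggyback on the extend code path, while only DECODE and TARGET_VERIFY batches are eligible for CUDA-graph replay. A quick sanity sketch using only the enum defined above:

    mode = ForwardMode.TARGET_VERIFY
    assert mode.is_extend()        # verification reuses the extend kernels
    assert mode.is_cuda_graph()    # capturable, like regular decode
    assert not mode.is_decode()

    assert ForwardMode.DRAFT_EXTEND.is_extend()
    assert not ForwardMode.DRAFT_EXTEND.is_cuda_graph()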
@@ -141,14 +161,18 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None

-    # For Qwen2-VL
-    mrope_positions: torch.Tensor = None
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None

     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False

+    # For Qwen2-VL
+    mrope_positions: torch.Tensor = None
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -351,3 +375,18 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
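Note: CaptureHiddenMode tells the forward pass which hidden states to keep for the speculative draft model. A hedged sketch of how a consumer might branch on it; the tensor handling below is illustrative, not sglang's logits-processor code:

    import torch

    def capture_hidden(hidden_states: torch.Tensor, mode: CaptureHiddenMode):
        # hidden_states: [num_tokens, hidden_size] for the whole batch
        if not mode.need_capture():
            return None
        if mode.is_full():
            return hidden_states        # FULL: keep every position
        return hidden_states[-1:, :]    # LAST: keep only the final position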
sglang/srt/model_executor/model_runner.py

@@ -95,12 +95,6 @@ class ModelRunner:
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
-            # FIXME(HandH1998)
-            if (
-                "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
-                and not self.server_args.disable_cuda_graph
-            ):
-                self.server_args.disable_cuda_graph = True

         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -435,6 +429,10 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg

+    def update_weights_from_tensor(self, name, tensor: torch.Tensor):
+        self.model.load_weights([(name, tensor)])
+        return True, "Success"  # TODO error handling
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
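Note: this completes the new in-memory weight-update chain (TokenizerManager -> TpModelWorker -> ModelRunner -> model.load_weights). A hedged sketch of driving it from code that holds a TokenizerManager, assuming UpdateWeightsFromTensorReqInput is constructible from the two fields (name, tensor) that tp_worker.py unpacks:

    import torch
    from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput

    async def push_weight(tokenizer_manager, name: str, tensor: torch.Tensor):
        req = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
        # Resolves once every data-parallel scheduler has replied; the
        # ModelRunner ends up calling self.model.load_weights([(name, tensor)]).
        success, message = await tokenizer_manager.update_weights_from_tensor(req)
        return success, message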
sglang/srt/model_loader/loader.py

@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             quant_state_dict,
         )

+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(
         self, hf_weights_files, use_safetensors, quant_state_dict
     ) -> Generator:
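Note: these predicates replace the old ".weight"/".bias" suffix checks with an explicit list of bitsandbytes bookkeeping suffixes. A quick illustration; the methods ignore self, so they are called unbound here, and the weight names are made up:

    is_8bit = BitsAndBytesModelLoader._is_8bit_weight_name
    is_4bit = BitsAndBytesModelLoader._is_4bit_weight_name

    assert is_8bit(None, "model.layers.0.mlp.down_proj.SCB")     # 8-bit scale
    assert is_4bit(None, "model.layers.0.q_proj.weight.absmax")  # 4-bit state
    assert not is_4bit(None, "model.layers.0.q_proj.weight")     # plain weight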
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if not weight_name.lower().endswith(".scb"):
                 continue

-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor

         for weight_name, weight_tensor in self._hf_weight_iter(
             hf_weights_files, use_safetensors
         ):
-
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue

-            qweight_name = weight_name.replace(".weight", ".qweight")
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor

@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             hf_weights_files, use_safetensors
         ):

-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue

             if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
                 f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
             ):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor

sglang/srt/models/gemma2.py

@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):


 class Gemma2ForCausalLM(nn.Module):
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
sglang/srt/models/llama.py

@@ -325,8 +325,8 @@ class LlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
-        # Llama 3.2 1B Insturct set tie_word_embeddings to True
-        # Llama 3.1 8B Insturct set tie_word_embeddings to False
+        # Llama 3.2 1B Instruct sets tie_word_embeddings to True
+        # Llama 3.1 8B Instruct sets tie_word_embeddings to False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
         )
         return None

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
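Note: together with the new sglang/srt/models/llama_eagle.py in this release, these hooks let a speculative-decoding draft model share the target model's embedding and LM head instead of keeping its own copies. A hedged sketch of the intended wiring; the draft-model API below is an assumption, not taken from llama_eagle.py:

    def share_embed_and_head(target: LlamaForCausalLM, draft) -> None:
        embed, head = target.get_embed_and_head()
        # set_embed_and_head deletes the old parameters before rebinding,
        # so the freed memory is reclaimed by the empty_cache() call.
        draft.set_embed_and_head(embed, head)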