sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +11 -2
  4. sglang/lang/backend/openai.py +10 -0
  5. sglang/srt/aio_rwlock.py +100 -0
  6. sglang/srt/configs/model_config.py +8 -1
  7. sglang/srt/constrained/xgrammar_backend.py +6 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  10. sglang/srt/layers/linear.py +20 -2
  11. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  12. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  13. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  14. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
  15. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  16. sglang/srt/layers/moe/topk.py +205 -0
  17. sglang/srt/layers/quantization/__init__.py +3 -3
  18. sglang/srt/layers/quantization/fp8.py +169 -32
  19. sglang/srt/layers/quantization/fp8_kernel.py +292 -0
  20. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  21. sglang/srt/layers/torchao_utils.py +11 -15
  22. sglang/srt/managers/schedule_batch.py +16 -10
  23. sglang/srt/managers/schedule_policy.py +1 -1
  24. sglang/srt/managers/scheduler.py +13 -16
  25. sglang/srt/managers/tokenizer_manager.py +130 -111
  26. sglang/srt/mem_cache/memory_pool.py +15 -8
  27. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  28. sglang/srt/model_loader/loader.py +22 -11
  29. sglang/srt/models/dbrx.py +1 -1
  30. sglang/srt/models/deepseek.py +1 -1
  31. sglang/srt/models/deepseek_v2.py +67 -18
  32. sglang/srt/models/gemma2.py +19 -0
  33. sglang/srt/models/grok.py +1 -1
  34. sglang/srt/models/llama.py +2 -2
  35. sglang/srt/models/mixtral.py +2 -2
  36. sglang/srt/models/olmoe.py +1 -1
  37. sglang/srt/models/qwen2_moe.py +1 -1
  38. sglang/srt/models/xverse_moe.py +1 -1
  39. sglang/srt/openai_api/adapter.py +23 -0
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_params.py +9 -2
  42. sglang/srt/server.py +21 -37
  43. sglang/srt/utils.py +33 -44
  44. sglang/test/test_block_fp8.py +341 -0
  45. sglang/version.py +1 -1
  46. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
  47. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
  48. sglang/srt/layers/fused_moe_patch.py +0 -133
  49. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  50. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  51. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
  52. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -22,7 +22,7 @@ import signal
  import sys
  import time
  import uuid
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
  import fastapi
  import uvloop
@@ -30,6 +30,7 @@ import zmq
  import zmq.asyncio
  from fastapi import BackgroundTasks
 
+ from sglang.srt.aio_rwlock import RWLock
  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
  from sglang.srt.metrics.collector import TokenizerMetricsCollector
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import get_zmq_socket, kill_process_tree
+ from sglang.srt.utils import (
+ dataclass_to_string_truncated,
+ get_zmq_socket,
+ kill_process_tree,
+ )
 
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -82,6 +87,9 @@ class ReqState:
  created_time: float
  first_token_time: Optional[float] = None
 
+ # For streaming output
+ last_output_offset: int = 0
+
 
  class TokenizerManager:
  """TokenizerManager is a process that tokenizes the text."""
@@ -120,6 +128,7 @@ class TokenizerManager:
 
  self.is_generation = self.model_config.is_generation
  self.context_len = self.model_config.context_len
+ self.image_token_id = self.model_config.image_token_id
 
  # Create image processor placeholder
  self.image_processor = get_dummy_image_processor()
@@ -152,15 +161,27 @@ class TokenizerManager:
  self.to_create_loop = True
  self.rid_to_state: Dict[str, ReqState] = {}
 
- # For update model weights
- self.model_update_lock = asyncio.Lock()
- self.model_update_result = None
+ # The event to notify the weight sync is finished.
+ self.model_update_lock = RWLock()
+ self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+ None
+ )
+ self.asyncio_tasks = set()
 
  # For session info
  self.session_futures = {} # session_id -> asyncio event
 
  # Others
  self.gracefully_exit = False
+ self.init_weights_update_group_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
+ self.update_weights_from_distributed_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
+ self.get_weights_by_name_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
 
  # Metrics
  if self.enable_metrics:
@@ -178,11 +199,7 @@ class TokenizerManager:
  ):
  created_time = time.time()
 
- if self.to_create_loop:
- self.create_handle_loop()
-
- while self.model_update_lock.locked():
- await asyncio.sleep(0.001)
+ self.auto_create_handle_loop()
 
  if isinstance(obj, EmbeddingReqInput) and self.is_generation:
  raise ValueError(
@@ -191,17 +208,24 @@ class TokenizerManager:
  )
 
  obj.normalize_batch_and_arguments()
- is_single = obj.is_single
- if is_single:
- tokenized_obj = await self._tokenize_one_request(obj)
- self.send_to_scheduler.send_pyobj(tokenized_obj)
- async for response in self._wait_one_response(obj, request, created_time):
- yield response
- else:
- async for response in self._handle_batch_request(
- obj, request, created_time
- ):
- yield response
+
+ if self.server_args.log_requests:
+ logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+ async with self.model_update_lock.reader_lock:
+ is_single = obj.is_single
+ if is_single:
+ tokenized_obj = await self._tokenize_one_request(obj)
+ self.send_to_scheduler.send_pyobj(tokenized_obj)
+ async for response in self._wait_one_response(
+ obj, request, created_time
+ ):
+ yield response
+ else:
+ async for response in self._handle_batch_request(
+ obj, request, created_time
+ ):
+ yield response
 
  async def _tokenize_one_request(
  self,
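The `RWLock` used above comes from the new `sglang/srt/aio_rwlock.py` module added in this release; its implementation is not shown in this diff. The sketch below is only a rough illustration of an asyncio reader-writer lock exposing the `reader_lock` / `writer_lock` async context managers that the hunks above rely on, so that generate requests (readers) can overlap while a weight update (the writer) waits for exclusive access. The shipped class may differ, e.g. in fairness guarantees.

import asyncio
from contextlib import asynccontextmanager

class RWLock:
    # Illustrative only, not the sglang.srt.aio_rwlock implementation.
    def __init__(self):
        self._cond = asyncio.Condition()
        self._readers = 0
        self._writer_active = False

    @property
    def reader_lock(self):
        return self._acquire(shared=True)

    @property
    def writer_lock(self):
        return self._acquire(shared=False)

    @asynccontextmanager
    async def _acquire(self, shared: bool):
        async with self._cond:
            if shared:
                # Readers only wait for an active writer to finish.
                await self._cond.wait_for(lambda: not self._writer_active)
                self._readers += 1
            else:
                # The writer waits until no readers and no other writer remain.
                await self._cond.wait_for(
                    lambda: not self._writer_active and self._readers == 0
                )
                self._writer_active = True
        try:
            yield
        finally:
            async with self._cond:
                if shared:
                    self._readers -= 1
                else:
                    self._writer_active = False
                self._cond.notify_all()

With this interface, `generate_request` enters `reader_lock` so many requests can run concurrently, while `update_weights_from_disk` takes `writer_lock` and therefore waits for in-flight requests to drain instead of the old polling-and-sleep workaround.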
@@ -215,7 +239,7 @@ class TokenizerManager:
  if not self.server_args.disable_radix_cache:
  raise ValueError(
  "input_embeds is provided while disable_radix_cache is False. "
- "Please add `--disable-radix-cach` when you launch the server "
+ "Please add `--disable-radix-cache` when you launch the server "
  "if you want to use input_embeds as inputs."
  )
  input_embeds = obj.input_embeds
@@ -301,8 +325,8 @@ class TokenizerManager:
  state.out_list = []
  if state.finished:
  if self.server_args.log_requests:
- # Log requests
- logger.info(f"in={obj}, out={out}")
+ msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+ logger.info(msg)
  del self.rid_to_state[obj.rid]
  yield out
  break
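`dataclass_to_string_truncated` is a new helper imported from `sglang.srt.utils` (see the import hunk above); its body is not part of this diff. As an assumed sketch only, a helper with that name would render a request or output dataclass while clipping long field values (prompts, token lists, embeddings) so request logging stays bounded:

import dataclasses

def dataclass_to_string_truncated(obj, max_length: int = 2048) -> str:
    # Hypothetical stand-in, not the sglang implementation: repr() every field
    # of a dataclass instance and clip values longer than max_length.
    parts = []
    for field in dataclasses.fields(obj):
        value = repr(getattr(obj, field.name))
        if len(value) > max_length:
            value = value[:max_length] + " ...[truncated]"
        parts.append(f"{field.name}={value}")
    return f"{type(obj).__name__}({', '.join(parts)})"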
@@ -423,112 +447,89 @@ class TokenizerManager:
  self,
  obj: UpdateWeightFromDiskReqInput,
  request: Optional[fastapi.Request] = None,
- ):
- if self.to_create_loop:
- self.create_handle_loop()
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
 
  # default the load format to the server_args
  if obj.load_format is None:
  obj.load_format = self.server_args.load_format
+ logger.info("Start update_weights. Load format=%s", obj.load_format)
 
- if not self.model_update_lock.locked():
-
- async with self.model_update_lock:
- # wait for the previous generation requests to finish
- for i in range(3):
- while len(self.rid_to_state) > 0:
- await asyncio.sleep(0.001)
- # FIXME: We add some sleep here to avoid some race conditions.
- # We can use a read-write lock as a better fix.
- await asyncio.sleep(0.01)
- self.send_to_scheduler.send_pyobj(obj)
- self.model_update_result = asyncio.Future()
-
- if self.server_args.dp_size == 1:
- result = await self.model_update_result
- if result.success:
- self.server_args.model_path = obj.model_path
- self.server_args.load_format = obj.load_format
- self.model_path = obj.model_path
- return result.success, result.message
- else: # self.server_args.dp_size > 1
- self.model_update_tmp = []
- result = await self.model_update_result
-
- all_success = all([r.success for r in result])
- if all_success is True:
- self.server_args.model_path = obj.model_path
- self.server_args.load_format = obj.load_format
- self.model_path = obj.model_path
- all_message = [r.message for r in result]
- all_message = " | ".join(all_message)
- return all_success, all_message
+ if True:
+ # Hold the lock if it is not async. This means that weight sync
+ # cannot run while requests are in progress.
+ async with self.model_update_lock.writer_lock:
+ return await self._wait_for_model_update_from_disk(obj)
 
- else:
- return False, "Another update is in progress. Please try again later."
+ async def _wait_for_model_update_from_disk(
+ self, obj: UpdateWeightFromDiskReqInput
+ ) -> Tuple[bool, str]:
+ self.send_to_scheduler.send_pyobj(obj)
+ self.model_update_result = asyncio.Future()
+ if self.server_args.dp_size == 1:
+ result = await self.model_update_result
+ if result.success:
+ self.served_model_name = obj.model_path
+ self.server_args.model_path = obj.model_path
+ self.server_args.load_format = obj.load_format
+ self.model_path = obj.model_path
+ return result.success, result.message
+ else: # self.server_args.dp_size > 1
+ self.model_update_tmp = []
+ result = await self.model_update_result
+
+ all_success = all([r.success for r in result])
+ if all_success is True:
+ self.server_args.model_path = obj.model_path
+ self.server_args.load_format = obj.load_format
+ self.model_path = obj.model_path
+ all_message = [r.message for r in result]
+ all_message = " | ".join(all_message)
+ return all_success, all_message
 
  async def init_weights_update_group(
  self,
  obj: InitWeightsUpdateGroupReqInput,
  request: Optional[fastapi.Request] = None,
- ) -> bool:
- if self.to_create_loop:
- self.create_handle_loop()
- self.send_to_scheduler.send_pyobj(obj)
-
- self.init_weights_update_group_result = asyncio.Future()
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
  assert (
  self.server_args.dp_size == 1
  ), "dp_size must be 1 for init parameter update group"
- result = await self.init_weights_update_group_result
+ result = (await self.init_weights_update_group_communicator(obj))[0]
  return result.success, result.message
 
  async def update_weights_from_distributed(
  self,
  obj: UpdateWeightsFromDistributedReqInput,
  request: Optional[fastapi.Request] = None,
- ):
- if self.to_create_loop:
- self.create_handle_loop()
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ assert (
+ self.server_args.dp_size == 1
+ ), "dp_size must be for update weights from distributed"
 
- if not self.model_update_lock.locked():
- async with self.model_update_lock:
- self.send_to_scheduler.send_pyobj(obj)
- self.parameter_update_result = asyncio.Future()
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be for update weights from distributed"
- result = await self.parameter_update_result
- return result.success, result.message
- else:
- logger.error("Another parameter update is in progress in tokenizer manager")
- return (
- False,
- "Another parameter update is in progress. Please try again later.",
- )
+ # This means that weight sync
+ # cannot run while requests are in progress.
+ async with self.model_update_lock.writer_lock:
+ result = (await self.update_weights_from_distributed_communicator(obj))[0]
+ return result.success, result.message
 
  async def get_weights_by_name(
  self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
  ):
- if self.to_create_loop:
- self.create_handle_loop()
-
- self.send_to_scheduler.send_pyobj(obj)
- self.get_weights_by_name_result = asyncio.Future()
+ self.auto_create_handle_loop()
+ results = await self.get_weights_by_name_communicator(obj)
+ all_parameters = [r.parameter for r in results]
  if self.server_args.dp_size == 1:
- result = await self.get_weights_by_name_result
- return result.parameter
+ return all_parameters[0]
  else:
- self.get_weights_by_name_tmp = []
- result = await self.get_weights_by_name_result
- all_parameters = [r.parameter for r in result]
  return all_parameters
 
  async def open_session(
  self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
  ):
- if self.to_create_loop:
- self.create_handle_loop()
+ self.auto_create_handle_loop()
 
  session_id = uuid.uuid4().hex
  obj.session_id = session_id
@@ -558,17 +559,17 @@ class TokenizerManager:
  background_tasks.add_task(abort_request)
  return background_tasks
 
- def create_handle_loop(self):
+ def auto_create_handle_loop(self):
  if not self.to_create_loop:
  return
 
  self.to_create_loop = False
  loop = asyncio.get_event_loop()
- loop.create_task(self.handle_loop())
+ self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
 
  signal_handler = SignalHandler(self)
  loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
- loop.create_task(self.sigterm_watchdog())
+ self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
 
  async def sigterm_watchdog(self):
  while not self.gracefully_exit:
@@ -701,21 +702,14 @@ class TokenizerManager:
  assert (
  self.server_args.dp_size == 1
  ), "dp_size must be 1 for init parameter update group"
- self.init_weights_update_group_result.set_result(recv_obj)
+ self.init_weights_update_group_communicator.handle_recv(recv_obj)
  elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
  assert (
  self.server_args.dp_size == 1
  ), "dp_size must be 1 for update weights from distributed"
- self.parameter_update_result.set_result(recv_obj)
+ self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
  elif isinstance(recv_obj, GetWeightsByNameReqOutput):
- if self.server_args.dp_size == 1:
- self.get_weights_by_name_result.set_result(recv_obj)
- else:
- self.get_weights_by_name_tmp.append(recv_obj)
- if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
- self.get_weights_by_name_result.set_result(
- self.get_weights_by_name_tmp
- )
+ self.get_weights_by_name_communicator.handle_recv(recv_obj)
  else:
  raise ValueError(f"Invalid object: {recv_obj=}")
 
@@ -799,3 +793,28 @@ class SignalHandler:
  f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
  )
  self.tokenizer_manager.gracefully_exit = True
+
+
+ T = TypeVar("T")
+
+
+ class _Communicator(Generic[T]):
+ def __init__(self, sender, fan_out: int):
+ self._sender = sender
+ self._fan_out = fan_out
+ self._result_future: Optional[asyncio.Future] = None
+ self._result_values: Optional[List[T]] = None
+
+ async def __call__(self, obj):
+ self._sender.send_pyobj(obj)
+ self._result_future = asyncio.Future()
+ self._result_values = []
+ await self._result_future
+ result_values = self._result_values
+ self._result_future = self._result_values = None
+ return result_values
+
+ def handle_recv(self, recv_obj: T):
+ self._result_values.append(recv_obj)
+ if len(self._result_values) == self._fan_out:
+ self._result_future.set_result(None)
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
  device: str,
  ):
  super().__init__(size, dtype, device)
+ self.head_num = head_num
+ self.head_dim = head_dim
+ self.layer_num = layer_num
+ self._create_buffers()
 
+ def _create_buffers(self):
  # [size, head_num, head_dim] for each layer
  # The padded slot 0 is used for writing dummy outputs from padded tokens.
  self.k_buffer = [
  torch.empty(
- (size + 1, head_num, head_dim),
+ (self.size + 1, self.head_num, self.head_dim),
  dtype=self.store_dtype,
- device=device,
+ device=self.device,
  )
- for _ in range(layer_num)
+ for _ in range(self.layer_num)
  ]
  self.v_buffer = [
  torch.empty(
- (size + 1, head_num, head_dim),
+ (self.size + 1, self.head_num, self.head_dim),
  dtype=self.store_dtype,
- device=device,
+ device=self.device,
  )
- for _ in range(layer_num)
+ for _ in range(self.layer_num)
  ]
 
+ def _clear_buffers(self):
+ del self.k_buffer
+ del self.v_buffer
+
  def get_key_buffer(self, layer_id: int):
  if self.store_dtype != self.dtype:
  return self.k_buffer[layer_id].view(self.dtype)
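For reference, the buffer shapes above make KV-cache memory easy to estimate: each layer holds one `(size + 1, head_num, head_dim)` tensor for keys and one for values. A back-of-the-envelope sizing with made-up example dimensions (not sglang defaults):

import torch

# Example model dimensions (illustrative values only).
size, head_num, head_dim, layer_num = 65536, 32, 128, 32
dtype = torch.float16

elem_bytes = torch.tensor([], dtype=dtype).element_size()
per_buffer = (size + 1) * head_num * head_dim * elem_bytes   # one layer, K or V
total_bytes = 2 * layer_num * per_buffer                     # K and V for every layer
print(f"{total_bytes / 1024**3:.1f} GiB")                    # ~32.0 GiB for these dimensions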
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
 
 
  class MLATokenToKVPool(BaseTokenToKVPool):
-
  def __init__(
  self,
  size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):
 
 
  class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
-
  def __init__(
  self,
  size: int,
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -25,12 +25,12 @@ from vllm.distributed import get_tensor_model_parallel_rank
  from vllm.distributed.parallel_state import graph_capture
  from vllm.model_executor.custom_op import CustomOp
 
- from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
  from sglang.srt.layers.logits_processor import (
  LogitsMetadata,
  LogitsProcessor,
  LogitsProcessorOutput,
  )
+ from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
sglang/srt/model_loader/loader.py CHANGED
@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
  quant_state_dict,
  )
 
+ def _is_8bit_weight_name(self, weight_name: str):
+ quantized_suffix = {".scb", ".weight_format"}
+ return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
+
+ def _is_4bit_weight_name(self, weight_name: str):
+ quantized_suffix = {
+ "absmax",
+ "quant_map",
+ "nested_absmax",
+ "nested_quant_map",
+ "bitsandbytes",
+ }
+ suffix = weight_name.split(".")[-1]
+ return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
  def _quantized_8bit_generator(
  self, hf_weights_files, use_safetensors, quant_state_dict
  ) -> Generator:
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
  if not weight_name.lower().endswith(".scb"):
  continue
 
- weight_key = weight_name.lower().replace(".scb", ".qweight")
+ weight_key = weight_name.lower().replace(".scb", ".weight")
  quant_state_dict[weight_key] = weight_tensor
 
  for weight_name, weight_tensor in self._hf_weight_iter(
  hf_weights_files, use_safetensors
  ):
-
- if not weight_name.endswith((".weight", ".bias")):
+ if self._is_8bit_weight_name(weight_name):
  continue
 
- qweight_name = weight_name.replace(".weight", ".qweight")
-
- if qweight_name in quant_state_dict:
+ if weight_name in quant_state_dict:
  set_weight_attrs(weight_tensor, {"load_in_8bit": True})
- yield qweight_name, weight_tensor
+ yield weight_name, weight_tensor
  else:
  yield weight_name, weight_tensor
 
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
  weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
  temp_state_dict = {}
  for weight_name, weight_tensor in weight_iterator:
- if weight_name.endswith((".weight", ".bias")):
+ if not self._is_4bit_weight_name(weight_name):
  continue
  # bitsandbytes library requires
  # weight.quant_state.bitsandbytes__* in CPU
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
  hf_weights_files, use_safetensors
  ):
 
- if not weight_name.endswith((".weight", ".bias")):
+ if self._is_4bit_weight_name(weight_name):
  continue
 
  if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
  f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
  ):
  quant_state = _parse_quant_state(weight_name, temp_state_dict)
- weight_name = weight_name.replace(".weight", ".qweight")
  quant_state_dict[weight_name] = quant_state
- yield weight_name.replace(".weight", ".qweight"), weight_tensor
+ yield weight_name, weight_tensor
  else:
  yield weight_name, weight_tensor
 
sglang/srt/models/dbrx.py CHANGED
@@ -27,13 +27,13 @@ from vllm.distributed import (
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
- from sglang.srt.layers.fused_moe_triton import fused_moe
  from sglang.srt.layers.linear import (
  QKVParallelLinear,
  ReplicatedLinear,
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
  from vllm.model_executor.layers.rotary_embedding import get_rope
 
  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.fused_moe_triton import fused_moe
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
  MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (