sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +11 -2
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +205 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +292 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +13 -16
- sglang/srt/managers/tokenizer_manager.py +130 -111
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +23 -0
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +21 -37
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import fastapi
 import uvloop
@@ -30,6 +30,7 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
 
+from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    dataclass_to_string_truncated,
+    get_zmq_socket,
+    kill_process_tree,
+)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -82,6 +87,9 @@ class ReqState:
     created_time: float
     first_token_time: Optional[float] = None
 
+    # For streaming output
+    last_output_offset: int = 0
+
 
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -120,6 +128,7 @@ class TokenizerManager:
 
         self.is_generation = self.model_config.is_generation
         self.context_len = self.model_config.context_len
+        self.image_token_id = self.model_config.image_token_id
 
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -152,15 +161,27 @@
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
-        #
-        self.model_update_lock =
-        self.model_update_result =
+        # The event to notify the weight sync is finished.
+        self.model_update_lock = RWLock()
+        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+            None
+        )
+        self.asyncio_tasks = set()
 
         # For session info
         self.session_futures = {}  # session_id -> asyncio event
 
         # Others
         self.gracefully_exit = False
+        self.init_weights_update_group_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.update_weights_from_distributed_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.get_weights_by_name_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         # Metrics
         if self.enable_metrics:
@@ -178,11 +199,7 @@
     ):
         created_time = time.time()
 
-
-        self.create_handle_loop()
-
-        while self.model_update_lock.locked():
-            await asyncio.sleep(0.001)
+        self.auto_create_handle_loop()
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -191,17 +208,24 @@
         )
 
         obj.normalize_batch_and_arguments()
-
-        if
-
-
-
-
-
-
-
-
-
+
+        if self.server_args.log_requests:
+            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+        async with self.model_update_lock.reader_lock:
+            is_single = obj.is_single
+            if is_single:
+                tokenized_obj = await self._tokenize_one_request(obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                async for response in self._wait_one_response(
+                    obj, request, created_time
+                ):
+                    yield response
+            else:
+                async for response in self._handle_batch_request(
+                    obj, request, created_time
+                ):
+                    yield response
 
     async def _tokenize_one_request(
         self,
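Note: with this hunk, normal request handling now enters `self.model_update_lock.reader_lock`, while the weight-update paths below take `writer_lock`, so any number of requests may run concurrently but never overlap a weight sync. The new `sglang/srt/aio_rwlock.py` (+100 lines in this release) is not shown in this diff; the following is only a minimal sketch of an asyncio reader-writer lock exposing that `reader_lock`/`writer_lock` interface, not the shipped implementation.

```python
import asyncio
from contextlib import asynccontextmanager


class RWLockSketch:
    # Illustrative only -- NOT the RWLock shipped in sglang 0.4.1.post1.
    # Many readers may hold the lock at once; a writer needs exclusivity.
    def __init__(self):
        self._cond = asyncio.Condition()
        self._readers = 0
        self._writer_active = False

    @property
    def reader_lock(self):
        return self._reader_cm()

    @property
    def writer_lock(self):
        return self._writer_cm()

    @asynccontextmanager
    async def _reader_cm(self):
        async with self._cond:
            # block new readers only while a writer holds the lock
            await self._cond.wait_for(lambda: not self._writer_active)
            self._readers += 1
        try:
            yield
        finally:
            async with self._cond:
                self._readers -= 1
                self._cond.notify_all()

    @asynccontextmanager
    async def _writer_cm(self):
        async with self._cond:
            # wait until no writer and no readers are active
            await self._cond.wait_for(
                lambda: not self._writer_active and self._readers == 0
            )
            self._writer_active = True
        try:
            yield
        finally:
            async with self._cond:
                self._writer_active = False
                self._cond.notify_all()
```

This naive variant favors readers, so a steady request stream could starve a pending weight sync; a production lock would queue writers ahead of newly arriving readers.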
@@ -215,7 +239,7 @@
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
                     "input_embeds is provided while disable_radix_cache is False. "
-                    "Please add `--disable-radix-
+                    "Please add `--disable-radix-cache` when you launch the server "
                     "if you want to use input_embeds as inputs."
                 )
             input_embeds = obj.input_embeds
@@ -301,8 +325,8 @@
             state.out_list = []
             if state.finished:
                 if self.server_args.log_requests:
-
-                    logger.info(
+                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                    logger.info(msg)
                 del self.rid_to_state[obj.rid]
                 yield out
                 break
@@ -423,112 +447,89 @@
         self,
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
-
-        self.create_handle_loop()
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
 
         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
+        logger.info("Start update_weights. Load format=%s", obj.load_format)
 
-        if
-
-
-
-
-        while len(self.rid_to_state) > 0:
-            await asyncio.sleep(0.001)
-        # FIXME: We add some sleep here to avoid some race conditions.
-        # We can use a read-write lock as a better fix.
-        await asyncio.sleep(0.01)
-        self.send_to_scheduler.send_pyobj(obj)
-        self.model_update_result = asyncio.Future()
-
-        if self.server_args.dp_size == 1:
-            result = await self.model_update_result
-            if result.success:
-                self.server_args.model_path = obj.model_path
-                self.server_args.load_format = obj.load_format
-                self.model_path = obj.model_path
-            return result.success, result.message
-        else:  # self.server_args.dp_size > 1
-            self.model_update_tmp = []
-            result = await self.model_update_result
-
-            all_success = all([r.success for r in result])
-            if all_success is True:
-                self.server_args.model_path = obj.model_path
-                self.server_args.load_format = obj.load_format
-                self.model_path = obj.model_path
-            all_message = [r.message for r in result]
-            all_message = " | ".join(all_message)
-            return all_success, all_message
+        if True:
+            # Hold the lock if it is not async. This means that weight sync
+            # cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)
 
-
-
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message
 
     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> bool:
-
-        self.create_handle_loop()
-        self.send_to_scheduler.send_pyobj(obj)
-
-        self.init_weights_update_group_result = asyncio.Future()
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
         ), "dp_size must be 1 for init parameter update group"
-        result = await self.
+        result = (await self.init_weights_update_group_communicator(obj))[0]
         return result.success, result.message
 
     async def update_weights_from_distributed(
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
-
-
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be for update weights from distributed"
 
-
-
-
-
-
-            self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
-            result = await self.parameter_update_result
-            return result.success, result.message
-        else:
-            logger.error("Another parameter update is in progress in tokenizer manager")
-            return (
-                False,
-                "Another parameter update is in progress. Please try again later.",
-            )
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_weights_from_distributed_communicator(obj))[0]
+            return result.success, result.message
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
-
-
-
-        self.send_to_scheduler.send_pyobj(obj)
-        self.get_weights_by_name_result = asyncio.Future()
+        self.auto_create_handle_loop()
+        results = await self.get_weights_by_name_communicator(obj)
+        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
-
-            return result.parameter
+            return all_parameters[0]
         else:
-            self.get_weights_by_name_tmp = []
-            result = await self.get_weights_by_name_result
-            all_parameters = [r.parameter for r in result]
             return all_parameters
 
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-
-        self.create_handle_loop()
+        self.auto_create_handle_loop()
 
         session_id = uuid.uuid4().hex
         obj.session_id = session_id
@@ -558,17 +559,17 @@
         background_tasks.add_task(abort_request)
         return background_tasks
 
-    def
+    def auto_create_handle_loop(self):
         if not self.to_create_loop:
             return
 
        self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        loop.create_task(self.handle_loop())
+        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
 
         signal_handler = SignalHandler(self)
         loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        loop.create_task(self.sigterm_watchdog())
+        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
@@ -701,21 +702,14 @@
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for init parameter update group"
-                self.
+                self.init_weights_update_group_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
-                self.
+                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-
-                    self.get_weights_by_name_result.set_result(recv_obj)
-                else:
-                    self.get_weights_by_name_tmp.append(recv_obj)
-                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
-                        self.get_weights_by_name_result.set_result(
-                            self.get_weights_by_name_tmp
-                        )
+                self.get_weights_by_name_communicator.handle_recv(recv_obj)
             else:
                 raise ValueError(f"Invalid object: {recv_obj=}")
 
@@ -799,3 +793,28 @@ class SignalHandler:
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True
+
+
+T = TypeVar("T")
+
+
+class _Communicator(Generic[T]):
+    def __init__(self, sender, fan_out: int):
+        self._sender = sender
+        self._fan_out = fan_out
+        self._result_future: Optional[asyncio.Future] = None
+        self._result_values: Optional[List[T]] = None
+
+    async def __call__(self, obj):
+        self._sender.send_pyobj(obj)
+        self._result_future = asyncio.Future()
+        self._result_values = []
+        await self._result_future
+        result_values = self._result_values
+        self._result_future = self._result_values = None
+        return result_values
+
+    def handle_recv(self, recv_obj: T):
+        self._result_values.append(recv_obj)
+        if len(self._result_values) == self._fan_out:
+            self._result_future.set_result(None)
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
     ):
         super().__init__(size, dtype, device)
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self._create_buffers()
 
+    def _create_buffers(self):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
 
+    def _clear_buffers(self):
+        del self.k_buffer
+        del self.v_buffer
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
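Note: storing `head_num`/`head_dim`/`layer_num` and moving allocation into `_create_buffers()` pairs naturally with the new `_clear_buffers()`: the pool's KV tensors can be dropped and later rebuilt with identical shapes. The call sites are not shown in this diff, so the helper below is only a hypothetical illustration of the pattern, not code from the package:

```python
import torch


def with_kv_cache_released(pool, fn):
    # Hypothetical helper: free the pool's GPU memory before a
    # heavyweight operation, then re-allocate identical buffers.
    pool._clear_buffers()      # drop k_buffer/v_buffer references
    torch.cuda.empty_cache()   # return freed blocks to the allocator
    try:
        return fn()            # run while the KV-cache memory is free
    finally:
        pool._create_buffers()  # rebuild from stored size/head_num/...
```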
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):
 
 
 class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -25,12 +25,12 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
+from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
sglang/srt/model_loader/loader.py
CHANGED
@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             quant_state_dict,
         )
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(
         self, hf_weights_files, use_safetensors, quant_state_dict
     ) -> Generator:
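Note: the two predicates encode bitsandbytes naming conventions: 8-bit checkpoints carry `.SCB` (scale) and `.weight_format` entries, while 4-bit checkpoints carry quant-state tensors whose last name component contains `absmax`, `quant_map`, and similar. A standalone mirror of the logic, applied to invented entry names:

```python
# Standalone mirrors of the two loader predicates above, for illustration.
def is_8bit_weight_name(name: str) -> bool:
    return any(name.lower().endswith(s) for s in {".scb", ".weight_format"})


def is_4bit_weight_name(name: str) -> bool:
    suffix = name.split(".")[-1]
    return any(q in suffix for q in {
        "absmax", "quant_map", "nested_absmax", "nested_quant_map", "bitsandbytes",
    })


# Sample (made-up) checkpoint entry names:
assert is_8bit_weight_name("model.layers.0.mlp.down_proj.SCB")
assert not is_8bit_weight_name("model.layers.0.mlp.down_proj.weight")
assert is_4bit_weight_name("model.layers.0.up_proj.weight.absmax")
assert is_4bit_weight_name("model.layers.0.up_proj.weight.nested_quant_map")
assert not is_4bit_weight_name("model.layers.0.up_proj.weight")
```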
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if not weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
         for weight_name, weight_tensor in self._hf_weight_iter(
             hf_weights_files, use_safetensors
         ):
-
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
-
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             hf_weights_files, use_safetensors
         ):
 
-            if
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
                 f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
             ):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
sglang/srt/models/dbrx.py
CHANGED
@@ -27,13 +27,13 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek.py
CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
|