sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/bench_serving.py +11 -3
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
- sglang/srt/layers/moe/topk.py +14 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +91 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/io_struct.py +29 -8
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +71 -34
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +95 -55
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -6
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/llama.py +13 -2
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +79 -2
- sglang/srt/openai_api/protocol.py +50 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +45 -39
- sglang/srt/server_args.py +17 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ import signal
|
|
22
22
|
import sys
|
23
23
|
import time
|
24
24
|
import uuid
|
25
|
-
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
25
|
+
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
26
26
|
|
27
27
|
import fastapi
|
28
28
|
import uvloop
|
@@ -53,12 +53,15 @@ from sglang.srt.managers.io_struct import (
|
|
53
53
|
OpenSessionReqInput,
|
54
54
|
OpenSessionReqOutput,
|
55
55
|
ProfileReq,
|
56
|
+
SessionParams,
|
56
57
|
TokenizedEmbeddingReqInput,
|
57
58
|
TokenizedGenerateReqInput,
|
58
59
|
UpdateWeightFromDiskReqInput,
|
59
60
|
UpdateWeightFromDiskReqOutput,
|
60
61
|
UpdateWeightsFromDistributedReqInput,
|
61
62
|
UpdateWeightsFromDistributedReqOutput,
|
63
|
+
UpdateWeightsFromTensorReqInput,
|
64
|
+
UpdateWeightsFromTensorReqOutput,
|
62
65
|
)
|
63
66
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
64
67
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
@@ -173,6 +176,18 @@ class TokenizerManager:
|
|
173
176
|
|
174
177
|
# Others
|
175
178
|
self.gracefully_exit = False
|
179
|
+
self.init_weights_update_group_communicator = _Communicator(
|
180
|
+
self.send_to_scheduler, server_args.dp_size
|
181
|
+
)
|
182
|
+
self.update_weights_from_distributed_communicator = _Communicator(
|
183
|
+
self.send_to_scheduler, server_args.dp_size
|
184
|
+
)
|
185
|
+
self.update_weights_from_tensor_communicator = _Communicator(
|
186
|
+
self.send_to_scheduler, server_args.dp_size
|
187
|
+
)
|
188
|
+
self.get_weights_by_name_communicator = _Communicator(
|
189
|
+
self.send_to_scheduler, server_args.dp_size
|
190
|
+
)
|
176
191
|
|
177
192
|
# Metrics
|
178
193
|
if self.enable_metrics:
|
@@ -190,8 +205,7 @@ class TokenizerManager:
|
|
190
205
|
):
|
191
206
|
created_time = time.time()
|
192
207
|
|
193
|
-
|
194
|
-
self.create_handle_loop()
|
208
|
+
self.auto_create_handle_loop()
|
195
209
|
|
196
210
|
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
197
211
|
raise ValueError(
|
@@ -251,8 +265,9 @@ class TokenizerManager:
|
|
251
265
|
return_logprob = obj.return_logprob
|
252
266
|
logprob_start_len = obj.logprob_start_len
|
253
267
|
top_logprobs_num = obj.top_logprobs_num
|
254
|
-
|
255
|
-
|
268
|
+
session_params = (
|
269
|
+
SessionParams(**obj.session_params) if obj.session_params else None
|
270
|
+
)
|
256
271
|
|
257
272
|
if obj.input_ids is not None and len(input_ids) >= self.context_len:
|
258
273
|
raise ValueError(
|
@@ -279,8 +294,7 @@ class TokenizerManager:
|
|
279
294
|
obj.stream,
|
280
295
|
lora_path=obj.lora_path,
|
281
296
|
input_embeds=input_embeds,
|
282
|
-
|
283
|
-
session_rid=session_rid,
|
297
|
+
session_params=session_params,
|
284
298
|
)
|
285
299
|
elif isinstance(obj, EmbeddingReqInput):
|
286
300
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
@@ -440,8 +454,7 @@ class TokenizerManager:
|
|
440
454
|
obj: UpdateWeightFromDiskReqInput,
|
441
455
|
request: Optional[fastapi.Request] = None,
|
442
456
|
) -> Tuple[bool, str]:
|
443
|
-
|
444
|
-
self.create_handle_loop()
|
457
|
+
self.auto_create_handle_loop()
|
445
458
|
|
446
459
|
# default the load format to the server_args
|
447
460
|
if obj.load_format is None:
|
@@ -456,7 +469,7 @@ class TokenizerManager:
|
|
456
469
|
|
457
470
|
async def _wait_for_model_update_from_disk(
|
458
471
|
self, obj: UpdateWeightFromDiskReqInput
|
459
|
-
) -> Tuple[bool, str
|
472
|
+
) -> Tuple[bool, str]:
|
460
473
|
self.send_to_scheduler.send_pyobj(obj)
|
461
474
|
self.model_update_result = asyncio.Future()
|
462
475
|
if self.server_args.dp_size == 1:
|
@@ -485,15 +498,11 @@ class TokenizerManager:
|
|
485
498
|
obj: InitWeightsUpdateGroupReqInput,
|
486
499
|
request: Optional[fastapi.Request] = None,
|
487
500
|
) -> Tuple[bool, str]:
|
488
|
-
|
489
|
-
self.create_handle_loop()
|
490
|
-
self.send_to_scheduler.send_pyobj(obj)
|
491
|
-
|
492
|
-
self.init_weights_update_group_result = asyncio.Future()
|
501
|
+
self.auto_create_handle_loop()
|
493
502
|
assert (
|
494
503
|
self.server_args.dp_size == 1
|
495
504
|
), "dp_size must be 1 for init parameter update group"
|
496
|
-
result = await self.
|
505
|
+
result = (await self.init_weights_update_group_communicator(obj))[0]
|
497
506
|
return result.success, result.message
|
498
507
|
|
499
508
|
async def update_weights_from_distributed(
|
@@ -501,51 +510,59 @@ class TokenizerManager:
|
|
501
510
|
obj: UpdateWeightsFromDistributedReqInput,
|
502
511
|
request: Optional[fastapi.Request] = None,
|
503
512
|
) -> Tuple[bool, str]:
|
504
|
-
|
505
|
-
|
513
|
+
self.auto_create_handle_loop()
|
514
|
+
assert (
|
515
|
+
self.server_args.dp_size == 1
|
516
|
+
), "dp_size must be for update weights from distributed"
|
517
|
+
|
518
|
+
# This means that weight sync
|
519
|
+
# cannot run while requests are in progress.
|
520
|
+
async with self.model_update_lock.writer_lock:
|
521
|
+
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
522
|
+
return result.success, result.message
|
523
|
+
|
524
|
+
async def update_weights_from_tensor(
|
525
|
+
self,
|
526
|
+
obj: UpdateWeightsFromTensorReqInput,
|
527
|
+
request: Optional[fastapi.Request] = None,
|
528
|
+
) -> Tuple[bool, str]:
|
529
|
+
self.auto_create_handle_loop()
|
530
|
+
assert (
|
531
|
+
self.server_args.dp_size == 1
|
532
|
+
), "dp_size must be for update weights from distributed"
|
506
533
|
|
507
534
|
# This means that weight sync
|
508
535
|
# cannot run while requests are in progress.
|
509
536
|
async with self.model_update_lock.writer_lock:
|
510
|
-
self.
|
511
|
-
self.parameter_update_result: Awaitable[
|
512
|
-
UpdateWeightsFromDistributedReqOutput
|
513
|
-
] = asyncio.Future()
|
514
|
-
assert (
|
515
|
-
self.server_args.dp_size == 1
|
516
|
-
), "dp_size must be for update weights from distributed"
|
517
|
-
result = await self.parameter_update_result
|
537
|
+
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
518
538
|
return result.success, result.message
|
519
539
|
|
520
540
|
async def get_weights_by_name(
|
521
541
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
522
542
|
):
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
self.send_to_scheduler.send_pyobj(obj)
|
527
|
-
self.get_weights_by_name_result = asyncio.Future()
|
543
|
+
self.auto_create_handle_loop()
|
544
|
+
results = await self.get_weights_by_name_communicator(obj)
|
545
|
+
all_parameters = [r.parameter for r in results]
|
528
546
|
if self.server_args.dp_size == 1:
|
529
|
-
|
530
|
-
return result.parameter
|
547
|
+
return all_parameters[0]
|
531
548
|
else:
|
532
|
-
self.get_weights_by_name_tmp = []
|
533
|
-
result = await self.get_weights_by_name_result
|
534
|
-
all_parameters = [r.parameter for r in result]
|
535
549
|
return all_parameters
|
536
550
|
|
537
551
|
async def open_session(
|
538
552
|
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
539
553
|
):
|
540
|
-
|
541
|
-
|
554
|
+
self.auto_create_handle_loop()
|
555
|
+
|
556
|
+
if obj.session_id is None:
|
557
|
+
obj.session_id = uuid.uuid4().hex
|
558
|
+
elif obj.session_id in self.session_futures:
|
559
|
+
return None
|
542
560
|
|
543
|
-
session_id = uuid.uuid4().hex
|
544
|
-
obj.session_id = session_id
|
545
561
|
self.send_to_scheduler.send_pyobj(obj)
|
546
|
-
|
547
|
-
|
548
|
-
|
562
|
+
|
563
|
+
self.session_futures[obj.session_id] = asyncio.Future()
|
564
|
+
session_id = await self.session_futures[obj.session_id]
|
565
|
+
del self.session_futures[obj.session_id]
|
549
566
|
return session_id
|
550
567
|
|
551
568
|
async def close_session(
|
@@ -568,7 +585,7 @@ class TokenizerManager:
|
|
568
585
|
background_tasks.add_task(abort_request)
|
569
586
|
return background_tasks
|
570
587
|
|
571
|
-
def
|
588
|
+
def auto_create_handle_loop(self):
|
572
589
|
if not self.to_create_loop:
|
573
590
|
return
|
574
591
|
|
@@ -697,7 +714,7 @@ class TokenizerManager:
|
|
697
714
|
)
|
698
715
|
elif isinstance(recv_obj, OpenSessionReqOutput):
|
699
716
|
self.session_futures[recv_obj.session_id].set_result(
|
700
|
-
recv_obj.session_id
|
717
|
+
recv_obj.session_id if recv_obj.success else None
|
701
718
|
)
|
702
719
|
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
703
720
|
if self.server_args.dp_size == 1:
|
@@ -711,21 +728,19 @@ class TokenizerManager:
|
|
711
728
|
assert (
|
712
729
|
self.server_args.dp_size == 1
|
713
730
|
), "dp_size must be 1 for init parameter update group"
|
714
|
-
self.
|
731
|
+
self.init_weights_update_group_communicator.handle_recv(recv_obj)
|
715
732
|
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
716
733
|
assert (
|
717
734
|
self.server_args.dp_size == 1
|
718
735
|
), "dp_size must be 1 for update weights from distributed"
|
719
|
-
self.
|
736
|
+
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
737
|
+
elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
|
738
|
+
assert (
|
739
|
+
self.server_args.dp_size == 1
|
740
|
+
), "dp_size must be 1 for update weights from distributed"
|
741
|
+
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
|
720
742
|
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
721
|
-
|
722
|
-
self.get_weights_by_name_result.set_result(recv_obj)
|
723
|
-
else:
|
724
|
-
self.get_weights_by_name_tmp.append(recv_obj)
|
725
|
-
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
|
726
|
-
self.get_weights_by_name_result.set_result(
|
727
|
-
self.get_weights_by_name_tmp
|
728
|
-
)
|
743
|
+
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
729
744
|
else:
|
730
745
|
raise ValueError(f"Invalid object: {recv_obj=}")
|
731
746
|
|
@@ -809,3 +824,28 @@ class SignalHandler:
|
|
809
824
|
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
810
825
|
)
|
811
826
|
self.tokenizer_manager.gracefully_exit = True
|
827
|
+
|
828
|
+
|
829
|
+
T = TypeVar("T")
|
830
|
+
|
831
|
+
|
832
|
+
class _Communicator(Generic[T]):
|
833
|
+
def __init__(self, sender, fan_out: int):
|
834
|
+
self._sender = sender
|
835
|
+
self._fan_out = fan_out
|
836
|
+
self._result_future: Optional[asyncio.Future] = None
|
837
|
+
self._result_values: Optional[List[T]] = None
|
838
|
+
|
839
|
+
async def __call__(self, obj):
|
840
|
+
self._sender.send_pyobj(obj)
|
841
|
+
self._result_future = asyncio.Future()
|
842
|
+
self._result_values = []
|
843
|
+
await self._result_future
|
844
|
+
result_values = self._result_values
|
845
|
+
self._result_future = self._result_values = None
|
846
|
+
return result_values
|
847
|
+
|
848
|
+
def handle_recv(self, recv_obj: T):
|
849
|
+
self._result_values.append(recv_obj)
|
850
|
+
if len(self._result_values) == self._fan_out:
|
851
|
+
self._result_future.set_result(None)
|
sglang/srt/managers/tp_worker.py
CHANGED
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
|
|
24
24
|
InitWeightsUpdateGroupReqInput,
|
25
25
|
UpdateWeightFromDiskReqInput,
|
26
26
|
UpdateWeightsFromDistributedReqInput,
|
27
|
+
UpdateWeightsFromTensorReqInput,
|
27
28
|
)
|
28
29
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
29
30
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
@@ -188,6 +189,12 @@ class TpModelWorker:
|
|
188
189
|
)
|
189
190
|
return success, message
|
190
191
|
|
192
|
+
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
193
|
+
success, message = self.model_runner.update_weights_from_tensor(
|
194
|
+
recv_req.name, recv_req.tensor
|
195
|
+
)
|
196
|
+
return success, message
|
197
|
+
|
191
198
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
192
199
|
parameter = self.model_runner.get_weights_by_name(
|
193
200
|
recv_req.name, recv_req.truncate_size
|
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
|
|
28
28
|
InitWeightsUpdateGroupReqInput,
|
29
29
|
UpdateWeightFromDiskReqInput,
|
30
30
|
UpdateWeightsFromDistributedReqInput,
|
31
|
+
UpdateWeightsFromTensorReqInput,
|
31
32
|
)
|
32
33
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
33
34
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
|
|
225
226
|
success, message = self.worker.update_weights_from_distributed(recv_req)
|
226
227
|
return success, message
|
227
228
|
|
229
|
+
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
230
|
+
success, message = self.worker.update_weights_from_tensor(recv_req)
|
231
|
+
return success, message
|
232
|
+
|
228
233
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
229
234
|
return self.worker.get_weights_by_name(recv_req)
|
230
235
|
|
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
|
|
45
45
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
46
46
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
47
47
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
48
|
+
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
48
49
|
|
49
50
|
|
50
51
|
class ForwardMode(IntEnum):
|
@@ -59,6 +60,11 @@ class ForwardMode(IntEnum):
|
|
59
60
|
# No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
|
60
61
|
IDLE = auto()
|
61
62
|
|
63
|
+
# Used in speculative decoding: verify a batch in the target model.
|
64
|
+
TARGET_VERIFY = auto()
|
65
|
+
# Used in speculative decoding: extend a batch in the draft model.
|
66
|
+
DRAFT_EXTEND = auto()
|
67
|
+
|
62
68
|
# A dummy first batch to start the pipeline for overlap scheduler.
|
63
69
|
# It is now used for triggering the sampling_info_done event for the first prefill batch.
|
64
70
|
DUMMY_FIRST = auto()
|
@@ -67,7 +73,12 @@ class ForwardMode(IntEnum):
|
|
67
73
|
return self == ForwardMode.PREFILL
|
68
74
|
|
69
75
|
def is_extend(self):
|
70
|
-
return
|
76
|
+
return (
|
77
|
+
self == ForwardMode.EXTEND
|
78
|
+
or self == ForwardMode.MIXED
|
79
|
+
or self == ForwardMode.DRAFT_EXTEND
|
80
|
+
or self == self.TARGET_VERIFY
|
81
|
+
)
|
71
82
|
|
72
83
|
def is_decode(self):
|
73
84
|
return self == ForwardMode.DECODE
|
@@ -78,6 +89,15 @@ class ForwardMode(IntEnum):
|
|
78
89
|
def is_idle(self):
|
79
90
|
return self == ForwardMode.IDLE
|
80
91
|
|
92
|
+
def is_target_verify(self):
|
93
|
+
return self == ForwardMode.TARGET_VERIFY
|
94
|
+
|
95
|
+
def is_draft_extend(self):
|
96
|
+
return self == ForwardMode.DRAFT_EXTEND
|
97
|
+
|
98
|
+
def is_cuda_graph(self):
|
99
|
+
return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
|
100
|
+
|
81
101
|
def is_dummy_first(self):
|
82
102
|
return self == ForwardMode.DUMMY_FIRST
|
83
103
|
|
@@ -141,14 +161,18 @@ class ForwardBatch:
|
|
141
161
|
token_to_kv_pool: BaseTokenToKVPool = None
|
142
162
|
attn_backend: AttentionBackend = None
|
143
163
|
|
144
|
-
#
|
145
|
-
|
164
|
+
# Speculative decoding
|
165
|
+
spec_info: SpecInfo = None
|
166
|
+
spec_algorithm: SpeculativeAlgorithm = None
|
146
167
|
|
147
168
|
# For DP attention
|
148
169
|
global_num_tokens: Optional[List[int]] = None
|
149
170
|
gathered_buffer: Optional[torch.Tensor] = None
|
150
171
|
can_run_dp_cuda_graph: bool = False
|
151
172
|
|
173
|
+
# For Qwen2-VL
|
174
|
+
mrope_positions: torch.Tensor = None
|
175
|
+
|
152
176
|
def compute_mrope_positions(
|
153
177
|
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
154
178
|
):
|
@@ -351,3 +375,18 @@ def compute_position_torch(
|
|
351
375
|
extend_start_loc = torch.zeros_like(extend_seq_lens)
|
352
376
|
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
353
377
|
return positions.to(torch.int64), extend_start_loc
|
378
|
+
|
379
|
+
|
380
|
+
class CaptureHiddenMode(IntEnum):
|
381
|
+
NULL = auto()
|
382
|
+
FULL = auto()
|
383
|
+
LAST = auto()
|
384
|
+
|
385
|
+
def need_capture(self):
|
386
|
+
return self != CaptureHiddenMode.NULL
|
387
|
+
|
388
|
+
def is_full(self):
|
389
|
+
return self == CaptureHiddenMode.FULL
|
390
|
+
|
391
|
+
def is_last(self):
|
392
|
+
return self == CaptureHiddenMode.LAST
|
@@ -95,12 +95,6 @@ class ModelRunner:
|
|
95
95
|
):
|
96
96
|
logger.info("MLA optimization is turned on. Use triton backend.")
|
97
97
|
self.server_args.attention_backend = "triton"
|
98
|
-
# FIXME(HandH1998)
|
99
|
-
if (
|
100
|
-
"DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
|
101
|
-
and not self.server_args.disable_cuda_graph
|
102
|
-
):
|
103
|
-
self.server_args.disable_cuda_graph = True
|
104
98
|
|
105
99
|
if self.server_args.enable_double_sparsity:
|
106
100
|
logger.info(
|
@@ -435,6 +429,10 @@ class ModelRunner:
|
|
435
429
|
logger.error(error_msg)
|
436
430
|
return False, error_msg
|
437
431
|
|
432
|
+
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
|
433
|
+
self.model.load_weights([(name, tensor)])
|
434
|
+
return True, "Success" # TODO error handling
|
435
|
+
|
438
436
|
def get_weights_by_name(
|
439
437
|
self, name: str, truncate_size: int = 100
|
440
438
|
) -> Optional[torch.Tensor]:
|
@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|
770
770
|
quant_state_dict,
|
771
771
|
)
|
772
772
|
|
773
|
+
def _is_8bit_weight_name(self, weight_name: str):
|
774
|
+
quantized_suffix = {".scb", ".weight_format"}
|
775
|
+
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
|
776
|
+
|
777
|
+
def _is_4bit_weight_name(self, weight_name: str):
|
778
|
+
quantized_suffix = {
|
779
|
+
"absmax",
|
780
|
+
"quant_map",
|
781
|
+
"nested_absmax",
|
782
|
+
"nested_quant_map",
|
783
|
+
"bitsandbytes",
|
784
|
+
}
|
785
|
+
suffix = weight_name.split(".")[-1]
|
786
|
+
return any(q_suffix in suffix for q_suffix in quantized_suffix)
|
787
|
+
|
773
788
|
def _quantized_8bit_generator(
|
774
789
|
self, hf_weights_files, use_safetensors, quant_state_dict
|
775
790
|
) -> Generator:
|
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|
779
794
|
if not weight_name.lower().endswith(".scb"):
|
780
795
|
continue
|
781
796
|
|
782
|
-
weight_key = weight_name.lower().replace(".scb", ".
|
797
|
+
weight_key = weight_name.lower().replace(".scb", ".weight")
|
783
798
|
quant_state_dict[weight_key] = weight_tensor
|
784
799
|
|
785
800
|
for weight_name, weight_tensor in self._hf_weight_iter(
|
786
801
|
hf_weights_files, use_safetensors
|
787
802
|
):
|
788
|
-
|
789
|
-
if not weight_name.endswith((".weight", ".bias")):
|
803
|
+
if self._is_8bit_weight_name(weight_name):
|
790
804
|
continue
|
791
805
|
|
792
|
-
|
793
|
-
|
794
|
-
if qweight_name in quant_state_dict:
|
806
|
+
if weight_name in quant_state_dict:
|
795
807
|
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
796
|
-
yield
|
808
|
+
yield weight_name, weight_tensor
|
797
809
|
else:
|
798
810
|
yield weight_name, weight_tensor
|
799
811
|
|
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|
806
818
|
weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
|
807
819
|
temp_state_dict = {}
|
808
820
|
for weight_name, weight_tensor in weight_iterator:
|
809
|
-
if
|
821
|
+
if not self._is_4bit_weight_name(weight_name):
|
810
822
|
continue
|
811
823
|
# bitsandbytes library requires
|
812
824
|
# weight.quant_state.bitsandbytes__* in CPU
|
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|
830
842
|
hf_weights_files, use_safetensors
|
831
843
|
):
|
832
844
|
|
833
|
-
if
|
845
|
+
if self._is_4bit_weight_name(weight_name):
|
834
846
|
continue
|
835
847
|
|
836
848
|
if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
|
837
849
|
f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
|
838
850
|
):
|
839
851
|
quant_state = _parse_quant_state(weight_name, temp_state_dict)
|
840
|
-
weight_name = weight_name.replace(".weight", ".qweight")
|
841
852
|
quant_state_dict[weight_name] = quant_state
|
842
|
-
yield weight_name
|
853
|
+
yield weight_name, weight_tensor
|
843
854
|
else:
|
844
855
|
yield weight_name, weight_tensor
|
845
856
|
|
sglang/srt/models/gemma2.py
CHANGED
@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
|
|
307
307
|
|
308
308
|
|
309
309
|
class Gemma2ForCausalLM(nn.Module):
|
310
|
+
# BitandBytes specific attributes
|
311
|
+
default_bitsandbytes_target_modules = [
|
312
|
+
".gate_proj.",
|
313
|
+
".down_proj.",
|
314
|
+
".up_proj.",
|
315
|
+
".q_proj.",
|
316
|
+
".k_proj.",
|
317
|
+
".v_proj.",
|
318
|
+
".o_proj.",
|
319
|
+
]
|
320
|
+
bitsandbytes_stacked_params_mapping = {
|
321
|
+
# shard_name, weight_name, index
|
322
|
+
"q_proj": ("qkv_proj", 0),
|
323
|
+
"k_proj": ("qkv_proj", 1),
|
324
|
+
"v_proj": ("qkv_proj", 2),
|
325
|
+
"gate_proj": ("gate_up_proj", 0),
|
326
|
+
"up_proj": ("gate_up_proj", 1),
|
327
|
+
}
|
328
|
+
|
310
329
|
packed_modules_mapping = {
|
311
330
|
"qkv_proj": [
|
312
331
|
"q_proj",
|
sglang/srt/models/llama.py
CHANGED
@@ -325,8 +325,8 @@ class LlamaForCausalLM(nn.Module):
|
|
325
325
|
self.config = config
|
326
326
|
self.quant_config = quant_config
|
327
327
|
self.model = LlamaModel(config, quant_config=quant_config)
|
328
|
-
# Llama 3.2 1B
|
329
|
-
# Llama 3.1 8B
|
328
|
+
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
329
|
+
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
330
330
|
if self.config.tie_word_embeddings:
|
331
331
|
self.lm_head = self.model.embed_tokens
|
332
332
|
else:
|
@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
|
|
516
516
|
)
|
517
517
|
return None
|
518
518
|
|
519
|
+
def get_embed_and_head(self):
|
520
|
+
return self.model.embed_tokens.weight, self.lm_head.weight
|
521
|
+
|
522
|
+
def set_embed_and_head(self, embed, head):
|
523
|
+
del self.model.embed_tokens.weight
|
524
|
+
del self.lm_head.weight
|
525
|
+
self.model.embed_tokens.weight = embed
|
526
|
+
self.lm_head.weight = head
|
527
|
+
torch.cuda.empty_cache()
|
528
|
+
torch.cuda.synchronize()
|
529
|
+
|
519
530
|
|
520
531
|
class Phi3ForCausalLM(LlamaForCausalLM):
|
521
532
|
pass
|