sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +77 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/io_struct.py +29 -8
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/scheduler.py +60 -20
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +41 -10
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -0
- sglang/srt/models/llama.py +11 -0
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +60 -2
- sglang/srt/openai_api/protocol.py +48 -0
- sglang/srt/server.py +26 -3
- sglang/srt/server_args.py +17 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +32 -30
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/session_controller.py
CHANGED

```diff
@@ -10,41 +10,116 @@
 # limitations under the License.
 # ==============================================================================
 
+import logging
 import uuid
+from typing import Dict, Optional
 
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import Req
+
+
+class SessionReqNode:
+    def __init__(self, req, parent=None, childs=None):
+        self.req = req
+        self.parent = parent
+        if parent is not None:
+            parent.childs.append(self)
+        self.childs = [] if not childs else childs
+
+    def clear_childs(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+        self.childs = []
+
+    def clear(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+        del req_dict[self.req.rid]
+
+    def abort(self):
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+
+    def __str__(self):
+        return self._str_helper(self.req.rid)
+
+    def _str_helper(self, prefix=""):
+        if len(self.childs) == 0:
+            return prefix + "\n"
+        else:
+            origin_prefix = prefix
+            prefix += " -- " + self.childs[0].req.rid
+            ret = self.childs[0]._str_helper(prefix)
+            for child in self.childs[1:]:
+                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                ret += child._str_helper(prefix)
+            return ret
 
 
 class Session:
-    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
         self.session_id = session_id if session_id is not None else uuid.uuid4().hex
         self.capacity_of_str_len = capacity_of_str_len
-        self.
+        self.req_nodes: Dict[str, SessionReqNode] = {}
 
     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-
-
-
-
-
+        assert req.session_params is not None
+        session_params = req.session_params
+
+        last_req_node = None
+        last_req = None
+        abort = False
+        if session_params.replace:
+            if session_params.rid is None:
+                for _, req_node in self.req_nodes.items():
+                    req_node.clear(self.req_nodes)
+            else:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req_node.abort()
+                    last_req = last_req_node.req
+                    last_req_node.clear_childs(self.req_nodes)
         else:
-
-
+            if session_params.rid is not None:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req = last_req_node.req
+                    if not last_req.finished():
+                        logging.warning(
+                            "The request in a session is appending to a request that hasn't finished."
+                        )
+                        abort = True
+
+        if last_req is not None:
+            # trim bos token if it is an append
+            if req.input_ids[0] == tokenizer.bos_token_id:
+                req.input_ids = req.input_ids[1:]
+
             input_ids = (
-
-                +
-                : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids = input_ids[: session_params.offset] + req.input_ids
+            else:
+                input_ids += req.input_ids
             input_ids_unpadded = (
-
-                +
-                : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids_unpadded
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids_unpadded = (
+                    input_ids_unpadded[: session_params.offset] + req.input_ids
+                )
+            else:
+                input_ids_unpadded += req.input_ids
         else:
             input_ids = req.input_ids
             input_ids_unpadded = req.input_ids
@@ -57,13 +132,13 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
-        if
-            new_req.image_inputs =
+        if last_req is not None:
+            new_req.image_inputs = last_req.image_inputs
         new_req.tokenizer = tokenizer
-        if
-            new_req.
-                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
-            )
+        if abort:
+            new_req.to_abort = True
         else:
-
+            new_req_node = SessionReqNode(new_req, last_req_node)
+            self.req_nodes[req.rid] = new_req_node
+
         return new_req
```
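The session controller now keeps a tree of `SessionReqNode` objects instead of a flat request list, so one branch of a session can be replaced or aborted without touching its siblings. The following is a minimal sketch of that tree semantics; `DummyReq` is a hypothetical stand-in for sglang's `Req`, modeling only the fields the tree touches (`rid`, `finished_reason`, `to_abort`).

```python
# Minimal, self-contained sketch of the SessionReqNode tree behavior (illustrative only).
class DummyReq:
    def __init__(self, rid):
        self.rid = rid
        self.finished_reason = None  # None means "still running"
        self.to_abort = False

class Node:
    def __init__(self, req, parent=None):
        self.req, self.parent, self.childs = req, parent, []
        if parent is not None:
            parent.childs.append(self)

    def clear(self, req_dict):
        # Recursively abort unfinished descendants and drop them from the index.
        for child in self.childs:
            child.clear(req_dict)
        if self.req.finished_reason is None:
            self.req.to_abort = True
        del req_dict[self.req.rid]

req_dict = {}
root = Node(DummyReq("a")); req_dict["a"] = root
leaf = Node(DummyReq("b"), parent=root); req_dict["b"] = leaf

root.clear(req_dict)
assert req_dict == {} and leaf.req.to_abort  # whole branch removed, unfinished leaf marked for abort
```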
sglang/srt/managers/tokenizer_manager.py
CHANGED

```diff
@@ -53,12 +53,15 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -179,6 +182,9 @@ class TokenizerManager:
         self.update_weights_from_distributed_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_weights_from_tensor_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -259,8 +265,9 @@ class TokenizerManager:
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
-
-
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
 
             if obj.input_ids is not None and len(input_ids) >= self.context_len:
                 raise ValueError(
@@ -287,8 +294,7 @@ class TokenizerManager:
                 obj.stream,
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
-
-                session_rid=session_rid,
+                session_params=session_params,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -515,6 +521,22 @@ class TokenizerManager:
             result = (await self.update_weights_from_distributed_communicator(obj))[0]
             return result.success, result.message
 
+    async def update_weights_from_tensor(
+        self,
+        obj: UpdateWeightsFromTensorReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be for update weights from distributed"
+
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_weights_from_tensor_communicator(obj))[0]
+            return result.success, result.message
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -531,12 +553,16 @@ class TokenizerManager:
     ):
         self.auto_create_handle_loop()
 
-        session_id
-
+        if obj.session_id is None:
+            obj.session_id = uuid.uuid4().hex
+        elif obj.session_id in self.session_futures:
+            return None
+
         self.send_to_scheduler.send_pyobj(obj)
-
-
-
+
+        self.session_futures[obj.session_id] = asyncio.Future()
+        session_id = await self.session_futures[obj.session_id]
+        del self.session_futures[obj.session_id]
         return session_id
 
     async def close_session(
@@ -688,7 +714,7 @@ class TokenizerManager:
                 )
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
+                    recv_obj.session_id if recv_obj.success else None
                 )
             elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
@@ -708,6 +734,11 @@ class TokenizerManager:
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
                 self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 self.get_weights_by_name_communicator.handle_recv(recv_obj)
             else:
```
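The reworked `open_session` path parks an `asyncio.Future` keyed by the session id and resolves it when the scheduler's `OpenSessionReqOutput` arrives, with the id on success and `None` otherwise. A small, self-contained sketch of that handshake (illustrative, not sglang code):

```python
# Future-based open_session handshake, reduced to its core (assumed names).
import asyncio
import uuid

session_futures = {}

async def open_session(session_id=None):
    session_id = session_id or uuid.uuid4().hex
    if session_id in session_futures:
        return None  # an open for this id is already in flight
    session_futures[session_id] = asyncio.Future()
    result = await session_futures[session_id]
    del session_futures[session_id]
    return result

async def scheduler_reply(session_id, success=True):
    # Mirrors handle_loop: resolve with the id on success, None on failure.
    session_futures[session_id].set_result(session_id if success else None)

async def main():
    task = asyncio.create_task(open_session("demo"))
    await asyncio.sleep(0)           # let open_session register its future
    await scheduler_reply("demo")    # pretend the scheduler answered
    print(await task)                # -> "demo"

asyncio.run(main())
```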
sglang/srt/managers/tp_worker.py
CHANGED
```diff
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -188,6 +189,12 @@ class TpModelWorker:
         )
         return success, message
 
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.model_runner.update_weights_from_tensor(
+            recv_req.name, recv_req.tensor
+        )
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.model_runner.get_weights_by_name(
             recv_req.name, recv_req.truncate_size
```
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED

```diff
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
         success, message = self.worker.update_weights_from_distributed(recv_req)
         return success, message
 
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.worker.update_weights_from_tensor(recv_req)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)
 
```
sglang/srt/model_executor/forward_batch_info.py
CHANGED

```diff
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
 
 
 class ForwardMode(IntEnum):
@@ -59,6 +60,11 @@ class ForwardMode(IntEnum):
     # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
     IDLE = auto()
 
+    # Used in speculative decoding: verify a batch in the target model.
+    TARGET_VERIFY = auto()
+    # Used in speculative decoding: extend a batch in the draft model.
+    DRAFT_EXTEND = auto()
+
     # A dummy first batch to start the pipeline for overlap scheduler.
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()
@@ -67,7 +73,12 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.PREFILL
 
     def is_extend(self):
-        return
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.MIXED
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == self.TARGET_VERIFY
+        )
 
     def is_decode(self):
         return self == ForwardMode.DECODE
@@ -78,6 +89,15 @@ class ForwardMode(IntEnum):
     def is_idle(self):
         return self == ForwardMode.IDLE
 
+    def is_target_verify(self):
+        return self == ForwardMode.TARGET_VERIFY
+
+    def is_draft_extend(self):
+        return self == ForwardMode.DRAFT_EXTEND
+
+    def is_cuda_graph(self):
+        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
@@ -141,14 +161,18 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
-    #
-
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
 
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
+    # For Qwen2-VL
+    mrope_positions: torch.Tensor = None
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -351,3 +375,18 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
```
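The new speculative-decoding modes slot into the existing predicates: `TARGET_VERIFY` and `DRAFT_EXTEND` are treated as extend passes, while only `DECODE` and `TARGET_VERIFY` are CUDA-graph capturable. Re-stated in isolation so the behavior can be sanity-checked outside sglang (member order here is arbitrary):

```python
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()
    MIXED = auto()
    IDLE = auto()
    TARGET_VERIFY = auto()   # verify a batch in the target model
    DRAFT_EXTEND = auto()    # extend a batch in the draft model

    def is_extend(self):
        return self in (
            ForwardMode.EXTEND,
            ForwardMode.MIXED,
            ForwardMode.DRAFT_EXTEND,
            ForwardMode.TARGET_VERIFY,
        )

    def is_cuda_graph(self):
        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)

assert ForwardMode.TARGET_VERIFY.is_extend()      # verification runs as an extend pass
assert ForwardMode.TARGET_VERIFY.is_cuda_graph()  # ...and is still CUDA-graph capturable
assert not ForwardMode.DRAFT_EXTEND.is_cuda_graph()
```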
sglang/srt/model_executor/model_runner.py
CHANGED

```diff
@@ -429,6 +429,10 @@
             logger.error(error_msg)
             return False, error_msg
 
+    def update_weights_from_tensor(self, name, tensor: torch.Tensor):
+        self.model.load_weights([(name, tensor)])
+        return True, "Success" # TODO error handling
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
```
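The per-tensor update path ends here: the request carries a single `(name, tensor)` pair and `ModelRunner` hands it straight to the model's weight loader. A rough sketch of the same idea with plain PyTorch, emulating the loader with `load_state_dict(strict=False)` (illustrative only, not sglang's classes):

```python
import torch
from torch import nn

class TinyRunner:
    """Hypothetical stand-in for the runner's single-tensor update."""

    def __init__(self, model: nn.Module):
        self.model = model

    def update_weights_from_tensor(self, name: str, tensor: torch.Tensor):
        # One named tensor in, loaded in place; no restart, no checkpoint on disk.
        self.model.load_state_dict({name: tensor}, strict=False)
        return True, "Success"

runner = TinyRunner(nn.Linear(4, 4, bias=False))
new_weight = torch.zeros(4, 4)
ok, msg = runner.update_weights_from_tensor("weight", new_weight)
assert ok and torch.equal(runner.model.weight, new_weight)
```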
sglang/srt/models/llama.py
CHANGED
```diff
@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
         )
         return None
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
```
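`get_embed_and_head` / `set_embed_and_head` let one model hand its embedding and LM-head parameters to another, so a draft model (such as the EAGLE head below) can reuse the target model's tensors instead of holding copies. A minimal CPU-only sketch of that rebinding pattern (illustrative, not sglang code):

```python
import torch
from torch import nn

target = nn.Embedding(10, 4)   # plays the target model's embed_tokens
draft = nn.Embedding(10, 4)    # plays the draft model's embed_tokens

shared = target.weight         # "get_embed_and_head" on the target
del draft.weight               # "set_embed_and_head" on the draft: drop its own tensor...
draft.weight = shared          # ...and rebind to the target's Parameter

target.weight.data.fill_(1.0)
assert torch.all(draft.weight == 1.0)  # both modules now see the same storage
```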
sglang/srt/models/llama_eagle.py
ADDED

```diff
@@ -0,0 +1,132 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+
+
+class LlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config, prefix)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if layer_id == 0:
+            del self.input_layernorm
+            setattr(self, "input_layernorm", lambda x: x)
+
+
+class LlamaModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.fc(
+            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+        )
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        return hidden_states + residual
+
+
+class LlamaForCausalLMEagle(LlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+        # Llama 3.2 1B Instruct set tie_word_embeddings to True
+        # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            super().load_weights([(name, loaded_weight)])
+
+
+EntryClass = [LlamaForCausalLMEagle]
```
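The distinctive step in this EAGLE draft model is the fusion layer `self.fc`: token embeddings are concatenated with the hidden states carried over from the target model (`forward_batch.spec_info.hidden_states`) and projected back to `hidden_size` before the decoder layers run. That step in isolation, shapes only (illustrative values):

```python
import torch

hidden_size, num_tokens = 8, 3
fc = torch.nn.Linear(hidden_size * 2, hidden_size)

token_embeds = torch.randn(num_tokens, hidden_size)   # embed_tokens(input_ids)
target_hidden = torch.randn(num_tokens, hidden_size)  # spec_info.hidden_states from the target model
fused = fc(torch.cat((token_embeds, target_hidden), dim=-1))

assert fused.shape == (num_tokens, hidden_size)
```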
sglang/srt/openai_api/adapter.py
CHANGED
```diff
@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
     FileDeleteResponse,
     FileRequest,
     FileResponse,
+    FunctionResponse,
     LogProbs,
+    ToolCall,
     TopLogprob,
     UsageInfo,
 )
+from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -879,6 +882,21 @@ def v1_chat_generate_request(
         # None skips any image processing in GenerateReqInput.
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
+            tools = None
+            if request.tools and request.tool_choice != "none":
+                request.skip_special_tokens = False
+                if request.stream:
+                    logger.warning("Streaming is not supported with tools.")
+                    request.stream = False
+                if not isinstance(request.tool_choice, str):
+                    tools = [
+                        item.function.model_dump()
+                        for item in request.tools
+                        if item.function.name == request.tool_choice.function.name
+                    ]
+                else:
+                    tools = [item.function.model_dump() for item in request.tools]
+
             if chat_template_name is None:
                 openai_compatible_messages = []
                 for message in request.messages:
@@ -902,6 +920,7 @@ def v1_chat_generate_request(
                     openai_compatible_messages,
                     tokenize=True,
                     add_generation_prompt=True,
+                    tools=tools,
                 )
                 if assistant_prefix:
                     prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
@@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
 
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
+        tool_calls = None
+        text = ret_item["text"]
+
+        if isinstance(request, list):
+            tool_choice = request[idx].tool_choice
+            tools = request[idx].tools
+        else:
+            tool_choice = request.tool_choice
+            tools = request.tools
+
+        if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
+            if finish_reason == "stop":
+                finish_reason = "tool_calls"
+            try:
+                text, call_info_list = parse_tool_response(text, tools) # noqa
+                tool_calls = [
+                    ToolCall(
+                        id=str(call_info[0]),
+                        function=FunctionResponse(
+                            name=call_info[1], arguments=call_info[2]
+                        ),
+                    )
+                    for call_info in call_info_list
+                ]
+            except Exception as e:
+                logger.error(f"Exception: {e}")
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "Failed to parse fc related info to json format!",
+                )
+
         if to_file:
             # to make the choice data json serializable
             choice_data = {
                 "index": 0,
-                "message": {
+                "message": {
+                    "role": "assistant",
+                    "content": ret_item["text"] if tool_calls is None else None,
+                    "tool_calls": tool_calls,
+                },
                 "logprobs": choice_logprobs,
                 "finish_reason": (finish_reason["type"] if finish_reason else ""),
                 "matched_stop": (
@@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
-                message=ChatMessage(
+                message=ChatMessage(
+                    role="assistant",
+                    content=ret_item["text"] if tool_calls is None else None,
+                    tool_calls=tool_calls,
+                ),
                 logprobs=choice_logprobs,
                 finish_reason=(finish_reason["type"] if finish_reason else ""),
                 matched_stop=(
```
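With `tools`/`tool_choice` handled in `v1_chat_generate_request` and `tool_calls` assembled in `v1_chat_generate_response`, an OpenAI-compatible client can exercise the new path without sglang-specific code. A hypothetical client-side example follows; the base URL, port, and model name are assumptions for a locally launched server, while the request/response shapes follow the standard OpenAI chat-completions schema:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")  # assumed local endpoint

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed model name
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)

choice = response.choices[0]
if choice.finish_reason == "tool_calls":
    for call in choice.message.tool_calls:
        print(call.function.name, call.function.arguments)
```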