sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/srt/configs/model_config.py +11 -2
  3. sglang/srt/layers/attention/__init__.py +0 -1
  4. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  5. sglang/srt/layers/logits_processor.py +30 -2
  6. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
  7. sglang/srt/layers/quantization/fp8.py +42 -2
  8. sglang/srt/layers/quantization/fp8_kernel.py +77 -18
  9. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  10. sglang/srt/managers/io_struct.py +29 -8
  11. sglang/srt/managers/schedule_batch.py +22 -15
  12. sglang/srt/managers/scheduler.py +60 -20
  13. sglang/srt/managers/session_controller.py +102 -27
  14. sglang/srt/managers/tokenizer_manager.py +41 -10
  15. sglang/srt/managers/tp_worker.py +7 -0
  16. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  17. sglang/srt/model_executor/forward_batch_info.py +42 -3
  18. sglang/srt/model_executor/model_runner.py +4 -0
  19. sglang/srt/models/llama.py +11 -0
  20. sglang/srt/models/llama_eagle.py +132 -0
  21. sglang/srt/openai_api/adapter.py +60 -2
  22. sglang/srt/openai_api/protocol.py +48 -0
  23. sglang/srt/server.py +26 -3
  24. sglang/srt/server_args.py +17 -30
  25. sglang/srt/speculative/spec_info.py +19 -0
  26. sglang/srt/utils.py +62 -0
  27. sglang/version.py +1 -1
  28. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +3 -3
  29. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +32 -30
  30. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
  31. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
  32. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/session_controller.py

@@ -10,41 +10,116 @@
 # limitations under the License.
 # ==============================================================================

+import logging
 import uuid
+from typing import Dict, Optional

 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+from sglang.srt.managers.schedule_batch import Req
+
+
+class SessionReqNode:
+    def __init__(self, req, parent=None, childs=None):
+        self.req = req
+        self.parent = parent
+        if parent is not None:
+            parent.childs.append(self)
+        self.childs = [] if not childs else childs
+
+    def clear_childs(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+        self.childs = []
+
+    def clear(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+        del req_dict[self.req.rid]
+
+    def abort(self):
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+
+    def __str__(self):
+        return self._str_helper(self.req.rid)
+
+    def _str_helper(self, prefix=""):
+        if len(self.childs) == 0:
+            return prefix + "\n"
+        else:
+            origin_prefix = prefix
+            prefix += " -- " + self.childs[0].req.rid
+            ret = self.childs[0]._str_helper(prefix)
+            for child in self.childs[1:]:
+                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                ret += child._str_helper(prefix)
+            return ret


 class Session:
-    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
         self.session_id = session_id if session_id is not None else uuid.uuid4().hex
         self.capacity_of_str_len = capacity_of_str_len
-        self.reqs: List[Req] = []
+        self.req_nodes: Dict[str, SessionReqNode] = {}

     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        if req.session_rid is not None:
-            while len(self.reqs) > 0:
-                if self.reqs[-1].rid == req.session_rid:
-                    break
-                self.reqs = self.reqs[:-1]
+        assert req.session_params is not None
+        session_params = req.session_params
+
+        last_req_node = None
+        last_req = None
+        abort = False
+        if session_params.replace:
+            if session_params.rid is None:
+                for _, req_node in self.req_nodes.items():
+                    req_node.clear(self.req_nodes)
+            else:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req_node.abort()
+                    last_req = last_req_node.req
+                    last_req_node.clear_childs(self.req_nodes)
         else:
-            self.reqs = []
-        if len(self.reqs) > 0:
+            if session_params.rid is not None:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req = last_req_node.req
+                    if not last_req.finished():
+                        logging.warning(
+                            "The request in a session is appending to a request that hasn't finished."
+                        )
+                        abort = True
+
+        if last_req is not None:
+            # trim bos token if it is an append
+            if req.input_ids[0] == tokenizer.bos_token_id:
+                req.input_ids = req.input_ids[1:]
+
             input_ids = (
-                self.reqs[-1].origin_input_ids
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids = input_ids[: session_params.offset] + req.input_ids
+            else:
+                input_ids += req.input_ids
             input_ids_unpadded = (
-                self.reqs[-1].origin_input_ids_unpadded
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids_unpadded
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids_unpadded = (
+                    input_ids_unpadded[: session_params.offset] + req.input_ids
+                )
+            else:
+                input_ids_unpadded += req.input_ids
         else:
             input_ids = req.input_ids
             input_ids_unpadded = req.input_ids
@@ -57,13 +132,13 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
-        if len(self.reqs) > 0:
-            new_req.image_inputs = self.reqs[-1].image_inputs
+        if last_req is not None:
+            new_req.image_inputs = last_req.image_inputs
         new_req.tokenizer = tokenizer
-        if req.session_rid is not None and len(self.reqs) == 0:
-            new_req.finished_reason = FINISH_ABORT(
-                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
-            )
+        if abort:
+            new_req.to_abort = True
         else:
-            self.reqs.append(new_req)
+            new_req_node = SessionReqNode(new_req, last_req_node)
+            self.req_nodes[req.rid] = new_req_node
+
         return new_req
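The net effect of this hunk: a session no longer keeps a flat list of requests but an index of SessionReqNode objects, so several continuations can branch off the same parent, and replacing a node aborts its unfinished descendants. A minimal standalone sketch of that clearing behavior (StubReq and Node below are illustrative stand-ins, not sglang imports):

# Sketch of the tree-clearing semantics of SessionReqNode above; StubReq is a
# hypothetical stand-in exposing only the fields the tree logic touches.
class StubReq:
    def __init__(self, rid):
        self.rid = rid
        self.finished_reason = None  # None means the request is still running
        self.to_abort = False


class Node:
    def __init__(self, req, parent=None):
        self.req = req
        self.childs = []
        if parent is not None:
            parent.childs.append(self)

    def clear(self, req_dict):
        # Abort unfinished descendants recursively and drop them from the index.
        for child in self.childs:
            child.clear(req_dict)
        if self.req.finished_reason is None:
            self.req.to_abort = True
        del req_dict[self.req.rid]


root = Node(StubReq("a"))
leaf = Node(StubReq("a.1"), parent=root)
req_dict = {"a": root, "a.1": leaf}

# Replacing the branch under "a" (cf. clear_childs) aborts the still-running leaf.
for child in list(root.childs):
    child.clear(req_dict)
print(leaf.req.to_abort, sorted(req_dict))  # True ['a']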
sglang/srt/managers/tokenizer_manager.py

@@ -53,12 +53,15 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -179,6 +182,9 @@ class TokenizerManager:
         self.update_weights_from_distributed_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_weights_from_tensor_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -259,8 +265,9 @@
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
-            session_id = obj.session[0] if obj.session else None
-            session_rid = obj.session[1] if obj.session else None
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )

         if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
@@ -287,8 +294,7 @@
                 obj.stream,
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
-                session_id=session_id,
-                session_rid=session_rid,
+                session_params=session_params,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -515,6 +521,22 @@
             result = (await self.update_weights_from_distributed_communicator(obj))[0]
             return result.success, result.message

+    async def update_weights_from_tensor(
+        self,
+        obj: UpdateWeightsFromTensorReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be for update weights from distributed"
+
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_weights_from_tensor_communicator(obj))[0]
+            return result.success, result.message
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -531,12 +553,16 @@
     ):
         self.auto_create_handle_loop()

-        session_id = uuid.uuid4().hex
-        obj.session_id = session_id
+        if obj.session_id is None:
+            obj.session_id = uuid.uuid4().hex
+        elif obj.session_id in self.session_futures:
+            return None
+
         self.send_to_scheduler.send_pyobj(obj)
-        self.session_futures[session_id] = asyncio.Future()
-        session_id = await self.session_futures[session_id]
-        del self.session_futures[session_id]
+
+        self.session_futures[obj.session_id] = asyncio.Future()
+        session_id = await self.session_futures[obj.session_id]
+        del self.session_futures[obj.session_id]
         return session_id

     async def close_session(
@@ -688,7 +714,7 @@
                 )
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
+                    recv_obj.session_id if recv_obj.success else None
                 )
             elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
@@ -708,6 +734,11 @@
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
                 self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 self.get_weights_by_name_communicator.handle_recv(recv_obj)
             else:
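Instead of the old two-element obj.session tuple (session id, session rid), the tokenizer manager now builds a SessionParams object from a session_params dict. A rough sketch of that shape, assuming only the fields Session.create_req reads above (rid, offset, replace); the real dataclass lives in sglang/srt/managers/io_struct.py and may carry more fields:

from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in mirroring the fields used by Session.create_req.
@dataclass
class SessionParamsSketch:
    rid: Optional[str] = None     # request in the session to continue from
    offset: Optional[int] = None  # cut the accumulated history at this token index before appending
    replace: bool = False         # abort and replace the existing branch instead of appending

payload = {"rid": "prev-request-id", "offset": None, "replace": False}
params = SessionParamsSketch(**payload)  # mirrors SessionParams(**obj.session_params)
print(params)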
sglang/srt/managers/tp_worker.py

@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -188,6 +189,12 @@ class TpModelWorker:
         )
         return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.model_runner.update_weights_from_tensor(
+            recv_req.name, recv_req.tensor
+        )
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.model_runner.get_weights_by_name(
             recv_req.name, recv_req.truncate_size
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
         success, message = self.worker.update_weights_from_distributed(recv_req)
         return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.worker.update_weights_from_tensor(recv_req)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)

sglang/srt/model_executor/forward_batch_info.py

@@ -45,6 +45,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm


 class ForwardMode(IntEnum):
@@ -59,6 +60,11 @@
     # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
     IDLE = auto()

+    # Used in speculative decoding: verify a batch in the target model.
+    TARGET_VERIFY = auto()
+    # Used in speculative decoding: extend a batch in the draft model.
+    DRAFT_EXTEND = auto()
+
     # A dummy first batch to start the pipeline for overlap scheduler.
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()
@@ -67,7 +73,12 @@
         return self == ForwardMode.PREFILL

     def is_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.MIXED
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == self.TARGET_VERIFY
+        )

     def is_decode(self):
         return self == ForwardMode.DECODE
@@ -78,6 +89,15 @@
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_target_verify(self):
+        return self == ForwardMode.TARGET_VERIFY
+
+    def is_draft_extend(self):
+        return self == ForwardMode.DRAFT_EXTEND
+
+    def is_cuda_graph(self):
+        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

@@ -141,14 +161,18 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None

-    # For Qwen2-VL
-    mrope_positions: torch.Tensor = None
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None

     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False

+    # For Qwen2-VL
+    mrope_positions: torch.Tensor = None
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -351,3 +375,18 @@
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
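Two details of the ForwardMode changes above are worth calling out: is_extend() now also returns True for the new TARGET_VERIFY and DRAFT_EXTEND modes, and is_cuda_graph() limits CUDA-graph replay to DECODE and TARGET_VERIFY batches. A tiny self-contained check of the same predicate logic (a sketch that mirrors, rather than imports, the enum above):

from enum import IntEnum, auto

class ModeSketch(IntEnum):  # mirrors only the relevant ForwardMode members
    EXTEND = auto()
    DECODE = auto()
    MIXED = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()

    def is_extend(self):
        return self in (ModeSketch.EXTEND, ModeSketch.MIXED,
                        ModeSketch.DRAFT_EXTEND, ModeSketch.TARGET_VERIFY)

    def is_cuda_graph(self):
        return self in (ModeSketch.DECODE, ModeSketch.TARGET_VERIFY)

# A target-verify batch counts as extend, yet stays eligible for CUDA-graph replay.
assert ModeSketch.TARGET_VERIFY.is_extend() and ModeSketch.TARGET_VERIFY.is_cuda_graph()
assert ModeSketch.DRAFT_EXTEND.is_extend() and not ModeSketch.DRAFT_EXTEND.is_cuda_graph()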
sglang/srt/model_executor/model_runner.py

@@ -429,6 +429,10 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg

+    def update_weights_from_tensor(self, name, tensor: torch.Tensor):
+        self.model.load_weights([(name, tensor)])
+        return True, "Success"  # TODO error handling
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
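ModelRunner.update_weights_from_tensor simply forwards a single (name, tensor) pair to the model's load_weights, so one named parameter is replaced in place without touching the rest of the checkpoint. A minimal sketch of that contract against a toy module (ToyModel and its load_weights are illustrative; the real sglang models implement load_weights with their own sharding and weight-stacking logic):

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(4, 8, bias=False)

    def load_weights(self, weights):
        # Same (name, tensor) iterable contract that update_weights_from_tensor relies on.
        params = dict(self.named_parameters())
        for name, tensor in weights:
            params[name].data.copy_(tensor)

model = ToyModel()
model.load_weights([("lm_head.weight", torch.zeros(8, 4))])  # cf. update_weights_from_tensor(name, tensor)
print(bool(torch.all(model.lm_head.weight == 0)))  # True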
sglang/srt/models/llama.py

@@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module):
             )
             return None

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
sglang/srt/models/llama_eagle.py

@@ -0,0 +1,132 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+
+
+class LlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config, prefix)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if layer_id == 0:
+            del self.input_layernorm
+            setattr(self, "input_layernorm", lambda x: x)
+
+
+class LlamaModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.fc(
+            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+        )
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        return hidden_states + residual
+
+
+class LlamaForCausalLMEagle(LlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+        # Llama 3.2 1B Instruct set tie_word_embeddings to True
+        # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            super().load_weights([(name, loaded_weight)])
+
+
+EntryClass = [LlamaForCausalLMEagle]
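The EAGLE draft model above differs from a plain LLaMA in two visible ways: layer 0 drops its input_layernorm, and every forward pass fuses the draft token embedding with the hidden states handed over by the target model (forward_batch.spec_info.hidden_states) through a Linear(2 * hidden_size, hidden_size) before the decoder layers run. A shape-only sketch of that fusion with toy sizes and no sglang imports:

import torch

batch, seq, hidden = 2, 5, 16                    # toy sizes
embed = torch.randn(batch, seq, hidden)          # draft-model token embeddings
target_hidden = torch.randn(batch, seq, hidden)  # hidden states from the target model
fc = torch.nn.Linear(hidden * 2, hidden)

# Same pattern as LlamaModel.forward above: concatenate on the feature dim, project back down.
fused = fc(torch.cat((embed, target_hidden), dim=-1))
print(fused.shape)  # torch.Size([2, 5, 16])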
sglang/srt/openai_api/adapter.py

@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
     FileDeleteResponse,
     FileRequest,
     FileResponse,
+    FunctionResponse,
     LogProbs,
+    ToolCall,
     TopLogprob,
     UsageInfo,
 )
+from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -879,6 +882,21 @@ def v1_chat_generate_request(
         # None skips any image processing in GenerateReqInput.
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
+            tools = None
+            if request.tools and request.tool_choice != "none":
+                request.skip_special_tokens = False
+                if request.stream:
+                    logger.warning("Streaming is not supported with tools.")
+                    request.stream = False
+                if not isinstance(request.tool_choice, str):
+                    tools = [
+                        item.function.model_dump()
+                        for item in request.tools
+                        if item.function.name == request.tool_choice.function.name
+                    ]
+                else:
+                    tools = [item.function.model_dump() for item in request.tools]
+
             if chat_template_name is None:
                 openai_compatible_messages = []
                 for message in request.messages:
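The branch above decides which tool schemas reach the chat template: every tool when tool_choice is a plain string such as "auto", only the named function when a specific tool_choice object is given, and none when tool_choice is "none". A small sketch of that filtering with plain dicts standing in for the pydantic request objects (the field layout follows the OpenAI-style schema the adapter consumes and is an assumption here). The remaining adapter.py hunks, which pass tools to apply_chat_template and build the tool_calls response, follow below.

# Hypothetical plain-dict stand-ins for request.tools / request.tool_choice.
tools = [
    {"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object"}}},
    {"type": "function", "function": {"name": "get_time", "parameters": {"type": "object"}}},
]

def select_tools(tools, tool_choice):
    # Mirrors the branching in v1_chat_generate_request above.
    if not tools or tool_choice == "none":
        return None
    if isinstance(tool_choice, str):  # e.g. "auto": expose every tool schema
        return [t["function"] for t in tools]
    wanted = tool_choice["function"]["name"]  # a specific choice keeps only that function
    return [t["function"] for t in tools if t["function"]["name"] == wanted]

print(len(select_tools(tools, "auto")))  # 2
print(select_tools(tools, {"type": "function", "function": {"name": "get_time"}}))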
@@ -902,6 +920,7 @@
                     openai_compatible_messages,
                     tokenize=True,
                     add_generation_prompt=True,
+                    tools=tools,
                 )
                 if assistant_prefix:
                     prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
@@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):

         finish_reason = ret_item["meta_info"]["finish_reason"]

+        tool_calls = None
+        text = ret_item["text"]
+
+        if isinstance(request, list):
+            tool_choice = request[idx].tool_choice
+            tools = request[idx].tools
+        else:
+            tool_choice = request.tool_choice
+            tools = request.tools
+
+        if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
+            if finish_reason == "stop":
+                finish_reason = "tool_calls"
+            try:
+                text, call_info_list = parse_tool_response(text, tools)  # noqa
+                tool_calls = [
+                    ToolCall(
+                        id=str(call_info[0]),
+                        function=FunctionResponse(
+                            name=call_info[1], arguments=call_info[2]
+                        ),
+                    )
+                    for call_info in call_info_list
+                ]
+            except Exception as e:
+                logger.error(f"Exception: {e}")
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "Failed to parse fc related info to json format!",
+                )
+
         if to_file:
             # to make the choice data json serializable
             choice_data = {
                 "index": 0,
-                "message": {"role": "assistant", "content": ret_item["text"]},
+                "message": {
+                    "role": "assistant",
+                    "content": ret_item["text"] if tool_calls is None else None,
+                    "tool_calls": tool_calls,
+                },
                 "logprobs": choice_logprobs,
                 "finish_reason": (finish_reason["type"] if finish_reason else ""),
                 "matched_stop": (
@@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
-                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                message=ChatMessage(
+                    role="assistant",
+                    content=ret_item["text"] if tool_calls is None else None,
+                    tool_calls=tool_calls,
+                ),
                 logprobs=choice_logprobs,
                 finish_reason=(finish_reason["type"] if finish_reason else ""),
                 matched_stop=(