sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -175,17 +176,25 @@ class SchedulerDisaggregationPrefillMixin:
     """
 
     @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
+    def event_loop_normal_disagg_prefill(self: Scheduler):
         """A normal scheduler loop for prefill worker in disaggregation mode."""
 
         while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            self.waiting_queue.extend(
-                self.
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
            self.cur_batch = batch
 
            if batch:
@@ -204,6 +213,48 @@ class SchedulerDisaggregationPrefillMixin:
            # Otherwise, it hangs under high concurrency
            self.running_batch.batch_is_full = False
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_prefill(self: Scheduler):
+        self.result_queue = deque()
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = self.result_queue.popleft()
+                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
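The overlap loop above deliberately handles each batch's result one iteration late: `run_batch` is launched for the current batch while the previous batch's entry is popped from `result_queue` and post-processed, so CPU-side bookkeeping overlaps with GPU execution. A toy sketch of that pipeline pattern (the helpers are stand-ins, not sglang APIs):

```python
from collections import deque

def toy_overlap_loop(batches):
    """One-iteration-delayed result handling, as in event_loop_overlap_disagg_prefill."""
    result_queue = deque()
    last_batch = None
    for batch in batches + [None]:  # None models an idle scheduler tick
        if batch is not None:
            # Stand-in for self.run_batch(batch): launch now, resolve later.
            result_queue.append((batch, f"pending-result({batch})"))
        if last_batch is not None:
            done_batch, result = result_queue.popleft()
            print(f"post-processing {result} while {batch} executes")
        last_batch = batch

toy_overlap_loop(["b0", "b1", "b2"])
```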
@@ -212,7 +263,26 @@ class SchedulerDisaggregationPrefillMixin:
        Adapted from process_batch_result_prefill
        """
 
-
+        (
+            logits_output,
+            next_token_ids,
+            extend_input_len_per_req,
+            extend_logprob_start_len_per_req,
+            bid,
+        ) = (
+            result.logits_output,
+            result.next_token_ids,
+            result.extend_input_len_per_req,
+            result.extend_logprob_start_len_per_req,
+            result.bid,
+        )
+
+        # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        if self.enable_overlap:
+            # wait
+            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+        else:
+            next_token_ids = result.next_token_ids.tolist()
 
        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
            req: Req
@@ -226,12 +296,8 @@ class SchedulerDisaggregationPrefillMixin:
                # being chunked reqs' prefill is not finished
                req.is_chunked -= 1
 
-
-
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            # We need to remove this for overlap schedule.
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+                if self.enable_overlap:
+                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
        """
@@ -260,7 +326,7 @@ class SchedulerDisaggregationPrefillMixin:
                raise Exception("Transferring failed")
 
        for req in done_reqs:
-            self.
+            self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
                req.metadata_buffer_index
            )
 
@@ -276,34 +342,67 @@ class SchedulerDisaggregationPrefillMixin:
            # only finished requests to running_batch.
            self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
-            self.
+            if self.enable_overlap:
+                # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+                self.chunked_req.tmp_end_idx = min(
+                    len(self.chunked_req.fill_ids),
+                    len(self.chunked_req.origin_input_ids),
+                )
+            else:
+                self.send_kv_chunk(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
            self.running_batch.batch_is_full = False
 
     def send_kv_chunk(
-        self: Scheduler,
+        self: Scheduler,
+        req: Req,
+        token_id: Optional[int] = None,
+        end_idx: Optional[int] = None,
     ) -> None:
        """
        Send a prefilled chunk to the decode server
        """
+        page_size = self.token_to_kv_pool_allocator.page_size
        start_idx = req.start_send_idx
-        end_idx
+        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
+        # the resolved length is not the same as fill_ids's length
+        end_idx = (
+            end_idx
+            if end_idx is not None
+            else min(len(req.fill_ids), len(req.origin_input_ids))
+        )
+        last_chunk = token_id is not None
+
+        if (not last_chunk) and (
+            end_idx % page_size != 0
+        ):  # todo: remove the second condition
+            # if not the last chunk and the last page is partial, delay the last partial page to the next send
+            end_idx = end_idx - end_idx % page_size
 
        # Update next start_send_idx
        req.start_send_idx = end_idx
 
        kv_indices = (
-            self.req_to_token_pool.req_to_token[req.req_pool_idx
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
            .cpu()
            .numpy()
        )
-        if
-        self.
+        if last_chunk is True:
+            self.disagg_prefill_bootstrap_queue.store_prefill_results(
                req.metadata_buffer_index, token_id
            )
-
-
-
+        page_indices = kv_to_page_indices(kv_indices, page_size)
+
+        page_start_idx = start_idx // page_size
+        page_end_idx = page_start_idx + len(page_indices)
+
+        if len(page_indices) == 0:
+            logger.info(
+                f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
+            )
+            return
+
+        req.disagg_kv_sender.send(
+            page_indices, slice(page_start_idx, page_end_idx), last_chunk
        )
-        req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
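The `end_idx` logic above keeps every non-final KV transfer page-aligned: when a mid-stream chunk ends inside a page, the partial page is withheld and re-sent once it fills, while the final chunk (signalled by `token_id is not None`) goes out unconditionally. A worked example of the rounding:

```python
page_size = 4

# Mid-stream chunk ending 2 tokens into the third page: hold those back.
end_idx, last_chunk = 10, False
if (not last_chunk) and end_idx % page_size != 0:
    end_idx = end_idx - end_idx % page_size
assert end_idx == 8  # tokens 8-9 are re-sent with the next chunk

# Final chunk: the partial page must go out now.
end_idx, last_chunk = 10, True
if (not last_chunk) and end_idx % page_size != 0:
    end_idx = end_idx - end_idx % page_size
assert end_idx == 10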
sglang/srt/disaggregation/utils.py
CHANGED
@@ -47,6 +47,7 @@ class ReqToMetadataIdxAllocator:
 
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
+    NIXL = "nixl"
     FAKE = "fake"
 
 
@@ -73,6 +74,21 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
        return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.nixl import (
+            NixlKVBootstrapServer,
+            NixlKVManager,
+            NixlKVReceiver,
+            NixlKVSender,
+        )
+
+        class_mapping = {
+            KVClassType.MANAGER: NixlKVManager,
+            KVClassType.SENDER: NixlKVSender,
+            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
    raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
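A short usage sketch of the factory above, assuming it lives in `sglang.srt.disaggregation.utils` alongside the enums; the NIXL classes are imported lazily, so the optional nixl dependency must be installed for this branch to work:

```python
from sglang.srt.disaggregation.utils import (
    KVClassType,
    TransferBackend,
    get_kv_class,
)

# Resolve the role-specific class for the new NIXL backend.
sender_cls = get_kv_class(TransferBackend.NIXL, KVClassType.SENDER)
print(sender_cls.__name__)  # NixlKVSender
```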
@@ -82,6 +98,7 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    # The return vector is kv_indices[::page_size] // page_size
    if page_size == 1:  # shortcut
        return kv_indices
+
    return kv_indices[::page_size] // page_size
 
 
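A worked example of `kv_to_page_indices`: with `page_size > 1`, sampling every `page_size`-th KV slot index and integer-dividing maps token slots to the pages that contain them, one entry per page (assuming page-aligned, contiguous-within-page indices):

```python
import numpy as np

kv_indices = np.arange(16, 24)  # KV slots 16..23 = pages 4 and 5 at page_size=4
print(kv_indices[::4] // 4)     # [4 5]: one page index per page
```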
sglang/srt/entrypoints/engine.py
CHANGED
@@ -279,6 +279,10 @@ class Engine(EngineBase):
        self.shutdown()
        return False
 
+    def flush_cache(self):
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(self.tokenizer_manager.flush_cache())
+
    def start_profile(self):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.tokenizer_manager.start_profile())
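A minimal sketch of the new offline entry point; the model path is a placeholder, and `flush_cache()` simply drives the tokenizer manager's async flush to completion on the current event loop:

```python
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
engine.generate("The capital of France is", {"max_new_tokens": 8})
ret = engine.flush_cache()  # drops radix-cache entries; refused while requests are in flight
engine.shutdown()
```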
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -25,11 +25,8 @@ import multiprocessing as multiprocessing
 import os
 import threading
 import time
-from ast import Mult
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
-
-from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from typing import AsyncIterator, Callable, Dict, Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -84,10 +81,10 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
+    get_bool_env_var,
     kill_process_tree,
     set_uvicorn_logging_configs,
 )
@@ -130,7 +127,10 @@ async def lifespan(fast_api_app: FastAPI):
 
 
 # Fast API
-app = FastAPI(
+app = FastAPI(
+    lifespan=lifespan,
+    openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
+)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
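This makes the OpenAPI schema (and the /docs UI backed by it) opt-out via the environment; a sketch, assuming `get_bool_env_var` treats the usual "true"/"1" strings as set:

```python
import os

# Set before launching the server process so the FastAPI app is
# constructed with openapi_url=None (no /openapi.json, no /docs).
os.environ["DISABLE_OPENAPI_DOC"] = "true"
```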
@@ -281,7 +281,9 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     )
 
     try:
-        ret = await _global_state.generate_request(
+        ret = await _global_state.tokenizer_manager.generate_request(
+            obj, request
+        ).__anext__()
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
@@ -315,11 +317,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
-    _global_state.tokenizer_manager.flush_cache()
+    ret = await _global_state.tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
-        status_code=200,
+        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
     )
 
 
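Because the handler now awaits the flush and propagates `ret.success`, clients can tell a refused flush (running or waiting requests) from a completed one by status code; a sketch against a locally running server (host and port are assumptions):

```python
import requests

resp = requests.post("http://localhost:30000/flush_cache")
if resp.status_code == 200:
    print("cache flushed")
else:  # 400 when requests are still running or queued
    print("flush refused:", resp.text)
```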
sglang/srt/function_call_parser.py
CHANGED
@@ -25,6 +25,7 @@ TOOLS_TAG_LIST = [
     "<tool_call>",
     "<|python_tag|>",
     "[TOOL_CALLS]",
+    "<|tool▁calls▁begin|>",
 ]
 
 
@@ -477,6 +478,136 @@ class Llama32Detector(BaseFormatDetector):
        )
 
 
+class DeepSeekV3Detector(BaseFormatDetector):
+    """
+    Detector for DeepSeek models.
+    Assumes function call format:
+    '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bot_token = "<|tool▁calls▁begin|>"
+        self.eot_token = "<|tool▁calls▁end|>"
+        self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
+        self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+        self._last_arguments = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a deepseek format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+        calls = []
+        try:
+            for match_result in match_result_list:
+                # Get function name
+                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+                func_name = func_detail.group(2)
+                func_args = func_detail.group(3)
+                func_args = json.loads(func_args)
+                # construct match_result for parse_base_json
+                match_result = {"name": func_name, "parameters": func_args}
+                calls.extend(self.parse_base_json(match_result, tools))
+            return StreamingParseResult(normal_text=normal_text, calls=calls)
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin=">" + name + "\n```json\n",
+            end="\n```<",
+            trigger=">" + name + "\n```json\n",
+        )
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing tool calls for DeepSeekV3 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        if self.bot_token not in current_text:
+            self._buffer = ""
+            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
+                if e_token in new_text:
+                    new_text = new_text.replace(e_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
+        calls: list[ToolCallItem] = []
+        try:
+            partial_match = re.search(
+                pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
+                string=current_text,
+                flags=re.DOTALL,
+            )
+            if partial_match:
+                func_name = partial_match.group(2).strip()
+                func_args_raw = partial_match.group(3).strip()
+
+                if not self.current_tool_name_sent:
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self._tool_indices.get(func_name, 0),
+                            name=func_name,
+                            parameters="",
+                        )
+                    )
+                    self.current_tool_name_sent = True
+                else:
+                    argument_diff = (
+                        func_args_raw[len(self._last_arguments) :]
+                        if func_args_raw.startswith(self._last_arguments)
+                        else func_args_raw
+                    )
+
+                    if argument_diff:
+                        calls.append(
+                            ToolCallItem(
+                                tool_index=self._tool_indices.get(func_name, 0),
+                                name=None,
+                                parameters=argument_diff,
+                            )
+                        )
+                        self._last_arguments += argument_diff
+
+                    if _is_complete_json(func_args_raw):
+                        result = StreamingParseResult(normal_text="", calls=calls)
+                        self._buffer = ""
+                        self._last_arguments = ""
+                        self.current_tool_name_sent = False
+                        return result
+
+            return StreamingParseResult(normal_text="", calls=calls)
+
+        except Exception as e:
+            logger.error(f"Error in parse_streaming_increment: {e}")
+            return StreamingParseResult(normal_text=current_text)
+
+
 class MultiFormatParser:
    def __init__(self, detectors: List[BaseFormatDetector]):
        """
@@ -543,6 +674,7 @@ class FunctionCallParser:
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
+        "deepseekv3": DeepSeekV3Detector,
    }
 
    def __init__(self, tools: List[Tool], tool_call_parser: str):
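A hedged sketch of the new detector on the wire format quoted in its docstring; the sample string is assembled without literal triple backticks so it survives this document, and a full `detect_and_parse` call would additionally need a `tools` list whose schemas include `get_current_weather`:

```python
from sglang.srt.function_call_parser import DeepSeekV3Detector

fence = "`" * 3  # the DeepSeek format wraps arguments in ```json ... ```
sample = (
    "Checking the weather.<|tool▁calls▁begin|><|tool▁call▁begin|>function"
    "<|tool▁sep|>get_current_weather\n" + fence + "json\n"
    '{"location": "Tokyo"}\n' + fence + "<|tool▁call▁end|><|tool▁calls▁end|>"
)

detector = DeepSeekV3Detector()
print(detector.has_tool_call(sample))  # True: the bot token is present
```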
sglang/srt/layers/activation.py
CHANGED
@@ -28,9 +28,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-_is_cuda =
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
+        **kwargs,
    ):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
+                **kwargs,
            )
        else:
            return self.forward_extend(
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
+                **kwargs,
            )
 
    def forward_decode(