sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations

 import logging
+from collections import deque
 from typing import TYPE_CHECKING, List, Optional

 import torch
@@ -31,6 +32,8 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    kv_to_page_indices,
+    kv_to_page_num,
     poll_and_all_reduce,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -103,7 +106,7 @@ class PrefillBootstrapQueue:
         kv_args.aux_item_lens = [
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
-        kv_args.ib_device =
+        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
@@ -154,7 +157,8 @@ class PrefillBootstrapQueue:
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert req.metadata_buffer_index is not None
-
+            num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
+            req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)

             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
@@ -171,6 +175,70 @@ class SchedulerDisaggregationPrefillMixin:
     Mixin for Scheduler to handle disaggregation prefill
     """

+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_overlap_disagg_prefill(self):
+        self.result_queue = deque()
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = self.result_queue.popleft()
+                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
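Note: the overlap loop above defers result handling by one iteration; each launched batch's result is queued, and the previous iteration's result is consumed while the current batch runs on the GPU. A minimal standalone sketch of that deferred-consumption pattern, using placeholder batch names rather than sglang objects:

    from collections import deque

    result_queue = deque()
    last_batch = None
    for batch in ["b0", "b1", "b2", None]:  # None stands for an idle iteration
        if batch is not None:
            result_queue.append((batch, f"result({batch})"))  # launch without waiting
        if last_batch is not None:
            done_batch, done_result = result_queue.popleft()  # consume one step late
            print("processing", done_batch, done_result)
        last_batch = batch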
@@ -179,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
         Adapted from process_batch_result_prefill
         """

-
+        (
+            logits_output,
+            next_token_ids,
+            extend_input_len_per_req,
+            extend_logprob_start_len_per_req,
+            bid,
+        ) = (
+            result.logits_output,
+            result.next_token_ids,
+            result.extend_input_len_per_req,
+            result.extend_logprob_start_len_per_req,
+            result.bid,
+        )
+
+        # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        if self.enable_overlap:
+            # wait
+            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+        else:
+            next_token_ids = result.next_token_ids.tolist()

         for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
             req: Req
@@ -193,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
                # being chunked reqs' prefill is not finished
                req.is_chunked -= 1

-
-
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            # We need to remove this for overlap schedule.
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+                if self.enable_overlap:
+                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)

     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -210,7 +293,7 @@ class SchedulerDisaggregationPrefillMixin:

         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
-            self.
+            self.attn_tp_cpu_group,
         )

         undone_reqs: List[Req] = []
@@ -243,31 +326,68 @@ class SchedulerDisaggregationPrefillMixin:
             # only finished requests to running_batch.
             self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req)
-
+            if (
+                self.enable_overlap
+            ):  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+                self.chunked_req.tmp_end_idx = min(
+                    len(self.chunked_req.fill_ids),
+                    len(self.chunked_req.origin_input_ids),
+                )
+            else:
+                self.send_kv_chunk(self.chunked_req)
             # chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
             self.running_batch.batch_is_full = False

     def send_kv_chunk(
-        self: Scheduler,
+        self: Scheduler,
+        req: Req,
+        token_id: Optional[int] = None,
+        end_idx: Optional[int] = None,
     ) -> None:
         """
         Send a prefilled chunk to the decode server
         """
+        page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        end_idx
+        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
+        # the resolved length is not the same as fill_ids's length
+        end_idx = (
+            end_idx
+            if end_idx is not None
+            else min(len(req.fill_ids), len(req.origin_input_ids))
+        )
+        last_chunk = token_id is not None
+
+        if (not last_chunk) and (
+            end_idx % page_size != 0
+        ):  # todo: remove the second condition
+            # if not the last chunk and the last page is partial, delay the last partial page to the next send
+            end_idx = end_idx - end_idx % page_size

         # Update next start_send_idx
         req.start_send_idx = end_idx

         kv_indices = (
-            self.req_to_token_pool.req_to_token[req.req_pool_idx
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        if
+        if last_chunk is True:
             self.disagg_prefill_pending_queue.store_prefill_results(
                 req.metadata_buffer_index, token_id
             )
-
-
+        page_indices = kv_to_page_indices(kv_indices, page_size)
+
+        page_start_idx = start_idx // page_size
+        page_end_idx = page_start_idx + len(page_indices)
+
+        if len(page_indices) == 0:
+            logger.info(
+                f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
+            )
+            return
+
+        req.disagg_kv_sender.send(
+            page_indices, slice(page_start_idx, page_end_idx), last_chunk
+        )
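Note: the page-alignment rule in send_kv_chunk (hold back a trailing partial page unless this is the final chunk carrying a token_id) reduces to simple integer arithmetic. A minimal sketch, assuming a page_size of 64 and a 100-token chunk:

    page_size = 64
    start_idx, end_idx = 0, 100   # 100 prefilled tokens so far, not the last chunk
    last_chunk = False

    if not last_chunk and end_idx % page_size != 0:
        end_idx -= end_idx % page_size  # 100 -> 64; the partial page 64..99 is sent later

    num_pages = (end_idx - start_idx + page_size - 1) // page_size
    print(end_idx, num_pages)  # 64 1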
sglang/srt/disaggregation/utils.py
CHANGED
@@ -4,6 +4,7 @@ from collections import deque
 from enum import Enum
 from typing import List

+import numpy as np
 import torch
 import torch.distributed as dist

@@ -46,6 +47,7 @@ class ReqToMetadataIdxAllocator:

 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
+    NIXL = "nixl"
     FAKE = "fake"

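Note: the new enum member is consumed by get_kv_class, which the next hunk extends. A hedged usage sketch, assuming sglang with its optional NIXL dependency is installed:

    from sglang.srt.disaggregation.utils import KVClassType, TransferBackend, get_kv_class

    sender_cls = get_kv_class(TransferBackend.NIXL, KVClassType.SENDER)  # resolves to NixlKVSender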
@@ -72,4 +74,34 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.nixl import (
+            NixlKVBootstrapServer,
+            NixlKVManager,
+            NixlKVReceiver,
+            NixlKVSender,
+        )
+
+        class_mapping = {
+            KVClassType.MANAGER: NixlKVManager,
+            KVClassType.SENDER: NixlKVSender,
+            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
+
+
+def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
+    # 1. The page is guaruanteed to be full except the last page.
+    # 2. page index = kv_index // page_size
+    # The return vector is kv_indices[::page_size] // page_size
+    if page_size == 1:  # shortcut
+        return kv_indices
+
+    return kv_indices[::page_size] // page_size
+
+
+def kv_to_page_num(num_kv_indices: int, page_size: int):
+    # ceil(num_kv_indices / page_size)
+    return (num_kv_indices + page_size - 1) // page_size
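Note: a worked example of the two page helpers above, assuming contiguous KV slots and a page_size of 4:

    import numpy as np

    kv_indices = np.arange(32, 42)  # 10 contiguous KV slots: 32..41
    page_size = 4

    # kv_to_page_indices: first slot of every page, divided by page_size
    print(kv_indices[::page_size] // page_size)            # [ 8  9 10]
    # kv_to_page_num: ceil(10 / 4) pages
    print((len(kv_indices) + page_size - 1) // page_size)  # 3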
sglang/srt/entrypoints/engine.py
CHANGED
@@ -279,6 +279,10 @@ class Engine(EngineBase):
         self.shutdown()
         return False

+    def flush_cache(self):
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(self.tokenizer_manager.flush_cache())
+
     def start_profile(self):
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.tokenizer_manager.start_profile())
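Note: a hypothetical use of the new synchronous Engine.flush_cache entry point; the model path is a placeholder:

    import sglang as sgl

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
    engine.generate("Hello", {"max_new_tokens": 8})
    engine.flush_cache()  # blocks until the tokenizer manager has flushed the radix cache
    engine.shutdown()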
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -25,11 +25,8 @@ import multiprocessing as multiprocessing
 import os
 import threading
 import time
-from ast import Mult
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
-
-from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from typing import AsyncIterator, Callable, Dict, Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -84,7 +81,6 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -315,11 +311,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
-    _global_state.tokenizer_manager.flush_cache()
+    ret = await _global_state.tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
-        status_code=200,
+        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
     )

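Note: the /flush_cache endpoint now propagates failure; a quick check against a locally running server (the address and port are assumptions):

    import requests

    resp = requests.post("http://localhost:30000/flush_cache")
    # 200 when flushed; 400 when requests are still running or waiting
    print(resp.status_code, resp.text)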
sglang/srt/entrypoints/verl_engine.py
CHANGED
@@ -12,18 +12,17 @@
 # limitations under the License.
 # ==============================================================================
 import os
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.server import Engine
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj


@@ -125,7 +124,7 @@ class VerlEngine:

     def update_weights_from_tensor(
         self,
-        named_tensors:
+        named_tensors: Iterable[Tuple[str, torch.Tensor]],
         load_format: Optional[str] = None,
     ):
         # Most naive implementation, can optimize a lot if it is bottleneck
@@ -154,9 +153,12 @@ class VerlEngine:
                 )
             ],
            load_format=load_format,
-            flush_cache=
+            flush_cache=False,
        )

+        if self._tp_rank == 0:
+            self._engine.tokenizer_manager.flush_cache()
+
     def release_memory_occupation(self):
         if self._tp_rank == 0:
             self._engine.release_memory_occupation()
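Note: the updated signature accepts any Iterable[Tuple[str, torch.Tensor]]. A sketch of the expected shape, with placeholder tensor names; the engine call is commented out because construction is omitted:

    from typing import Iterable, Tuple

    import torch

    named_tensors: Iterable[Tuple[str, torch.Tensor]] = [
        ("embed_tokens.weight", torch.zeros(8, 4)),  # placeholder names and shapes
        ("lm_head.weight", torch.zeros(8, 4)),
    ]
    # verl_engine.update_weights_from_tensor(named_tensors)  # cache now flushed on tp rank 0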
sglang/srt/function_call_parser.py
CHANGED
@@ -25,6 +25,7 @@ TOOLS_TAG_LIST = [
     "<tool_call>",
     "<|python_tag|>",
     "[TOOL_CALLS]",
+    "<|tool▁calls▁begin|>",
 ]


@@ -477,6 +478,64 @@ class Llama32Detector(BaseFormatDetector):
         )


+class DeepSeekV3Detector(BaseFormatDetector):
+    """
+    Detector for DeepSeek models.
+    Assumes function call format:
+    '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bot_token = "<|tool▁calls▁begin|>"
+        self.eot_token = "<|tool▁calls▁end|>"
+        self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
+        self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a deepseek format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+        calls = []
+        try:
+            for match_result in match_result_list:
+                # Get function name
+                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+                func_name = func_detail.group(2)
+                func_args = func_detail.group(3)
+                func_args = json.loads(func_args)
+                # construct match_result for parse_base_json
+                match_result = {"name": func_name, "parameters": func_args}
+                calls.extend(self.parse_base_json(match_result, tools))
+            return StreamingParseResult(normal_text=normal_text, calls=calls)
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin="<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>"
+            + name
+            + "\n```json\n",
+            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
+            trigger="<|tool▁calls▁begin|>",
+        )
+
+
 class MultiFormatParser:
     def __init__(self, detectors: List[BaseFormatDetector]):
         """
@@ -543,6 +602,7 @@ class FunctionCallParser:
         "llama3": Llama32Detector,
         "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
+        "deepseekv3": DeepSeekV3Detector,
     }

     def __init__(self, tools: List[Tool], tool_call_parser: str):
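Note: for illustration, a standalone parse of the tool-call format that DeepSeekV3Detector targets; this uses an escaped regex written for this example, not the detector's own attributes:

    import json
    import re

    text = (
        "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n"
        '```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|><|tool▁calls▁end|>'
    )
    pattern = re.compile(
        r"<\|tool▁call▁begin\|>(.*?)<\|tool▁sep\|>(.*?)\n```json\n(.*?)\n```<\|tool▁call▁end\|>",
        re.DOTALL,
    )
    for call_type, name, args in pattern.findall(text):
        print(call_type, name, json.loads(args))
        # -> function get_current_weather {'location': 'Tokyo'}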
sglang/srt/layers/activation.py
CHANGED
@@ -21,13 +21,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from sglang.srt.utils import is_cuda_available
-
-_is_cuda = is_cuda_available()
-
-if _is_cuda:
-    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 logger = logging.getLogger(__name__)

|