sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -2
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/decode.py +43 -0
- sglang/srt/disaggregation/mini_lb.py +69 -8
- sglang/srt/disaggregation/mooncake/conn.py +1 -1
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +100 -16
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +781 -150
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/rotary_embedding.py +6 -6
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/io_struct.py +14 -3
- sglang/srt/managers/schedule_batch.py +13 -0
- sglang/srt/managers/scheduler.py +16 -6
- sglang/srt/managers/tokenizer_manager.py +115 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +31 -13
- sglang/srt/model_executor/cuda_graph_runner.py +13 -8
- sglang/srt/model_executor/model_runner.py +19 -4
- sglang/srt/models/deepseek_v2.py +9 -6
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +52 -40
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/utils.py +46 -5
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -204,6 +205,40 @@ class SchedulerDisaggregationPrefillMixin:
         # Otherwise, it hangs under high concurrency
         self.running_batch.batch_is_full = False
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_prefill(self):
+        self.result_queue = deque()
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = self.result_queue.popleft()
+                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
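The new event_loop_overlap_disagg_prefill defers consuming each batch's result by one iteration: a (batch, result) pair is queued when the batch is launched and only resolved on the next pass, so CPU-side scheduling overlaps with GPU execution. Below is a minimal sketch of that deferred-consumption pattern; launch and resolve are toy stand-ins, not sglang APIs.

from collections import deque

# Toy stand-ins for run_batch / resolve_batch_result: launch returns a
# handle immediately; resolve "waits" for the work to finish.
def launch(batch):
    return ("handle", batch)

def resolve(handle):
    return sum(handle[1])  # pretend this blocks on the GPU

result_queue = deque()
last_batch = None
for batch in ([1, 2], [3, 4], [5, 6], None):
    if batch is not None:
        result_queue.append((batch, launch(batch)))  # enqueue, don't wait
    if last_batch is not None:
        queued_batch, handle = result_queue.popleft()
        print(queued_batch, "->", resolve(handle))   # consumed one step later
    last_batch = batch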
@@ -212,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
         Adapted from process_batch_result_prefill
         """
 
-        next_token_ids = result.next_token_ids.tolist()
+        (
+            logits_output,
+            next_token_ids,
+            extend_input_len_per_req,
+            extend_logprob_start_len_per_req,
+            bid,
+        ) = (
+            result.logits_output,
+            result.next_token_ids,
+            result.extend_input_len_per_req,
+            result.extend_logprob_start_len_per_req,
+            result.bid,
+        )
+
+        # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        if self.enable_overlap:
+            # wait
+            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+        else:
+            next_token_ids = result.next_token_ids.tolist()
 
         for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
             req: Req
@@ -226,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
-
-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            # We need to remove this for overlap schedule.
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+                if self.enable_overlap:
+                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -276,34 +326,68 @@ class SchedulerDisaggregationPrefillMixin:
                 # only finished requests to running_batch.
                 self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
-                self.send_kv_chunk(self.chunked_req)
+                if (
+                    self.enable_overlap
+                ):  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+                    self.chunked_req.tmp_end_idx = min(
+                        len(self.chunked_req.fill_ids),
+                        len(self.chunked_req.origin_input_ids),
+                    )
+                else:
+                    self.send_kv_chunk(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                 self.running_batch.batch_is_full = False
 
     def send_kv_chunk(
-        self: Scheduler, req: Req, token_id: Optional[int] = None
+        self: Scheduler,
+        req: Req,
+        token_id: Optional[int] = None,
+        end_idx: Optional[int] = None,
     ) -> None:
         """
         Send a prefilled chunk to the decode server
         """
+        page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
+        # the resolved length is not the same as fill_ids's length
+        end_idx = (
+            end_idx
+            if end_idx is not None
+            else min(len(req.fill_ids), len(req.origin_input_ids))
+        )
+        last_chunk = token_id is not None
+
+        if (not last_chunk) and (
+            end_idx % page_size != 0
+        ):  # todo: remove the second condition
+            # if not the last chunk and the last page is partial, delay the last partial page to the next send
+            end_idx = end_idx - end_idx % page_size
 
         # Update next start_send_idx
         req.start_send_idx = end_idx
 
         kv_indices = (
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        if token_id is not None:
+        if last_chunk is True:
             self.disagg_prefill_pending_queue.store_prefill_results(
                 req.metadata_buffer_index, token_id
             )
-        is_last = token_id is not None
-        page_indices = kv_to_page_indices(
-            kv_indices, self.token_to_kv_pool_allocator.page_size
+        page_indices = kv_to_page_indices(kv_indices, page_size)
+
+        page_start_idx = start_idx // page_size
+        page_end_idx = page_start_idx + len(page_indices)
+
+        if len(page_indices) == 0:
+            logger.info(
+                f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
+            )
+            return
+
+        req.disagg_kv_sender.send(
+            page_indices, slice(page_start_idx, page_end_idx), last_chunk
         )
-        req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
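send_kv_chunk now works in page units: every non-final chunk is truncated to a page boundary, and the sender receives page indices plus a page-level slice instead of raw token indices. A worked example of the arithmetic with made-up values (numpy only, nothing sglang-specific):

import numpy as np

page_size = 4
start_idx, end_idx = 8, 19          # token range prefilled so far
last_chunk = False

if not last_chunk and end_idx % page_size != 0:
    end_idx -= end_idx % page_size  # 19 -> 16: hold back the partial page

kv_indices = np.arange(32, 32 + (end_idx - start_idx))  # toy KV slot ids 32..39
page_indices = kv_indices[::page_size] // page_size     # [32, 36] // 4 -> [8, 9]
page_start_idx = start_idx // page_size                 # 8 // 4 = 2
page_end_idx = page_start_idx + len(page_indices)       # 2 + 2 = 4

print(page_indices, slice(page_start_idx, page_end_idx))  # [8 9] slice(2, 4, None)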
sglang/srt/disaggregation/utils.py
CHANGED
@@ -47,6 +47,7 @@ class ReqToMetadataIdxAllocator:
 
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
+    NIXL = "nixl"
     FAKE = "fake"
 
 
@@ -73,6 +74,21 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
        return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.nixl import (
+            NixlKVBootstrapServer,
+            NixlKVManager,
+            NixlKVReceiver,
+            NixlKVSender,
+        )
+
+        class_mapping = {
+            KVClassType.MANAGER: NixlKVManager,
+            KVClassType.SENDER: NixlKVSender,
+            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
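With this branch in place, callers resolve the NIXL implementations exactly the way they already resolve Mooncake's. A usage sketch, assuming the import paths shown in this diff:

from sglang.srt.disaggregation.utils import (
    KVClassType,
    TransferBackend,
    get_kv_class,
)

backend = TransferBackend("nixl")                          # == TransferBackend.NIXL
sender_cls = get_kv_class(backend, KVClassType.SENDER)     # -> NixlKVSender
manager_cls = get_kv_class(backend, KVClassType.MANAGER)   # -> NixlKVManager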
@@ -82,6 +98,7 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
     # The return vector is kv_indices[::page_size] // page_size
     if page_size == 1:  # shortcut
         return kv_indices
+
     return kv_indices[::page_size] // page_size
 
 
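kv_to_page_indices samples one KV slot per page and divides by the page size; a quick worked example:

import numpy as np

kv_indices = np.array([40, 41, 42, 43, 44, 45, 46, 47])  # two full pages of size 4
print(kv_indices[::4] // 4)                              # [40, 44] // 4 -> [10 11]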
sglang/srt/entrypoints/engine.py
CHANGED
@@ -279,6 +279,10 @@ class Engine(EngineBase):
         self.shutdown()
         return False
 
+    def flush_cache(self):
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(self.tokenizer_manager.flush_cache())
+
     def start_profile(self):
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.tokenizer_manager.start_profile())
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -25,11 +25,8 @@ import multiprocessing as multiprocessing
 import os
 import threading
 import time
-from ast import Mult
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
-
-from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from typing import AsyncIterator, Callable, Dict, Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -84,7 +81,6 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -315,11 +311,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
-    _global_state.tokenizer_manager.flush_cache()
+    ret = await _global_state.tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
-        status_code=200,
+        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
    )
 
 
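Since /flush_cache now returns HTTP 400 (HTTPStatus.BAD_REQUEST) when the flush is refused, clients can branch on the status code. A client-side sketch with requests, assuming the default local server port:

import requests

resp = requests.post("http://localhost:30000/flush_cache")
if resp.status_code == 200:
    print("cache flushed:", resp.text.strip())
else:  # 400 when running or waiting requests prevent the flush
    print("flush refused:", resp.status_code)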
sglang/srt/function_call_parser.py
CHANGED
@@ -25,6 +25,7 @@ TOOLS_TAG_LIST = [
     "<tool_call>",
     "<|python_tag|>",
     "[TOOL_CALLS]",
+    "<|tool▁calls▁begin|>",
 ]
 
 
@@ -477,6 +478,64 @@ class Llama32Detector(BaseFormatDetector):
         )
 
 
+class DeepSeekV3Detector(BaseFormatDetector):
+    """
+    Detector for DeepSeek models.
+    Assumes function call format:
+        '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bot_token = "<|tool▁calls▁begin|>"
+        self.eot_token = "<|tool▁calls▁end|>"
+        self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
+        self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a deepseek format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+        calls = []
+        try:
+            for match_result in match_result_list:
+                # Get function name
+                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+                func_name = func_detail.group(2)
+                func_args = func_detail.group(3)
+                func_args = json.loads(func_args)
+                # construct match_result for parse_base_json
+                match_result = {"name": func_name, "parameters": func_args}
+                calls.extend(self.parse_base_json(match_result, tools))
+            return StreamingParseResult(normal_text=normal_text, calls=calls)
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin="<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>"
+            + name
+            + "\n```json\n",
+            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
+            trigger="<|tool▁calls▁begin|>",
+        )
+
+
 class MultiFormatParser:
     def __init__(self, detectors: List[BaseFormatDetector]):
         """
@@ -543,6 +602,7 @@ class FunctionCallParser:
         "llama3": Llama32Detector,
         "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
+        "deepseekv3": DeepSeekV3Detector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
sglang/srt/layers/activation.py
CHANGED
@@ -28,9 +28,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul