sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
 
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
         self._done_flag = False
         self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
@@ -243,12 +246,12 @@ class HiCacheController:
             self.storage_backend = HiCacheFile()
             self.get_hash_str = get_hash_str
         elif storage_backend == "nixl":
-            from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+            from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
             self.storage_backend = HiCacheNixl()
             self.get_hash_str = get_hash_str
         elif storage_backend == "mooncake":
-            from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+            from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                 MooncakeStore,
                 get_hash_str_mooncake,
             )
@@ -278,6 +281,12 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
         # create a new communication group for synchronizing storage operations across TP workers
         self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
         if self.tp_world_size > 1:
@@ -525,7 +534,7 @@ class HiCacheController:
         host_indices: torch.Tensor,
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
-    ) ->
+    ) -> PrefetchOperation:
         """
         Prefetch KV caches from storage backend to host memory.
         """
@@ -586,11 +595,23 @@
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
             except Empty:
                 continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit the prefetching operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much memory is occupied
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
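As a quick illustration of the budget behind the new `prefetch_rate_limit_check`, here is a standalone sketch; the pool sizes are made-up numbers, not sglang defaults.

```python
# Rough sketch of the new prefetch budget: roughly 80% of the host-only KV slots
# may be locked by in-flight prefetches. Pool sizes below are hypothetical.
mem_pool_host_size = 1_000_000    # host memory pool slots (assumed)
mem_pool_device_size = 200_000    # device memory pool slots (assumed)
prefetch_capacity_limit = int(0.8 * (mem_pool_host_size - mem_pool_device_size))

def prefetch_rate_limit_check(prefetch_tokens_occupied: int) -> bool:
    # mirrors the method above: refuse new prefetches once the budget is exhausted
    return prefetch_tokens_occupied < prefetch_capacity_limit

print(prefetch_capacity_limit)             # 640000
print(prefetch_rate_limit_check(100_000))  # True
print(prefetch_rate_limit_check(700_000))  # False
```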
@@ -604,34 +625,36 @@
                 if operation is None:
                     continue
 
-                last_hash = operation.last_hash
-                tokens_to_fetch = operation.token_ids
-
                 storage_hit_count = 0
-                remaining_tokens = len(tokens_to_fetch)
-                hash_value = []
-                while remaining_tokens >= self.page_size:
-                    last_hash = self.get_hash_str(
-                        tokens_to_fetch[
-                            storage_hit_count : storage_hit_count + self.page_size
-                        ],
-                        last_hash,
-                    )
-
-                    # todo, more unified interface
-                    if not self.is_mooncake_backend():
-                        if not self.storage_backend.exists(last_hash):
-                            break
-                    hash_value.append(last_hash)
-                    storage_hit_count += self.page_size
-                    remaining_tokens -= self.page_size
-
-                if self.is_mooncake_backend():
-                    # deferring to batch exists for mooncake store
-                    exist_result = self.storage_backend.exists(hash_value)
-                    storage_hit_count = (
-                        sum(1 for v in exist_result.values() if v != 0) * self.page_size
-                    )
+                if self.prefetch_rate_limit_check():
+                    last_hash = operation.last_hash
+                    tokens_to_fetch = operation.token_ids
+
+                    remaining_tokens = len(tokens_to_fetch)
+                    hash_value = []
+                    while remaining_tokens >= self.page_size:
+                        last_hash = self.get_hash_str(
+                            tokens_to_fetch[
+                                storage_hit_count : storage_hit_count + self.page_size
+                            ],
+                            last_hash,
+                        )
+
+                        # todo, more unified interface
+                        if not self.is_mooncake_backend():
+                            if not self.storage_backend.exists(last_hash):
+                                break
+                        hash_value.append(last_hash)
+                        storage_hit_count += self.page_size
+                        remaining_tokens -= self.page_size
+
+                    if self.is_mooncake_backend():
+                        # deferring to batch exists for mooncake store
+                        exist_result = self.storage_backend.exists(hash_value)
+                        storage_hit_count = (
+                            sum(1 for v in exist_result.values() if v != 0)
+                            * self.page_size
+                        )
 
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
@@ -750,6 +773,8 @@
 
                 if self.is_mooncake_backend():
                     self.mooncake_page_backup(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_backup(operation, batch_size=128)
                 else:
                     self.generic_page_backup(operation)
 
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -216,7 +216,7 @@ class DetokenizerManager:
                 rids=recv_obj.rids,
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-                output_ids=
+                output_ids=recv_obj.decode_ids,
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
sglang/srt/managers/io_struct.py
CHANGED
@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.utils import ImageData
 
 # Handle serialization of Image for pydantic
 if TYPE_CHECKING:
@@ -45,7 +46,7 @@ class SessionParams:
 
 # Type definitions for multimodal input data
 # Individual data item types for each modality
-ImageDataInputItem = Union[Image, str, Dict]
+ImageDataInputItem = Union[Image, str, ImageData, Dict]
 AudioDataInputItem = Union[str, Dict]
 VideoDataInputItem = Union[str, Dict]
 # Union type for any multimodal data item
@@ -101,8 +102,10 @@ class GenerateReqInput:
 
     # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
-    # The path to the LoRA
+    # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # The uid of LoRA adaptors, should be initialized by tokenizer manager
+    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
@@ -123,6 +126,9 @@
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -500,7 +506,7 @@ class TokenizedGenerateReqInput:
     stream: bool
 
     # LoRA related
-
+    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -557,6 +563,9 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -1073,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
     lora_path: str
+    # Whether to pin the LoRA adapter in memory.
+    pinned: bool = False
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
     lora_id: Optional[str] = None
 
@@ -1081,6 +1092,7 @@
             lora_id=self.lora_id,
             lora_name=self.lora_name,
             lora_path=self.lora_path,
+            pinned=self.pinned,
         )
 
 
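For orientation, a minimal stand-in showing how the new `pinned` flag rides along on a load-adapter request; it mirrors only the fields visible in the hunk above, not the full class.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LoadLoRAAdapterReqInput:  # simplified stand-in for the request shown above
    lora_name: str
    lora_path: str
    pinned: bool = False           # new field: keep the adapter resident in memory
    lora_id: Optional[str] = None  # filled in later by the TokenizerManager

req = LoadLoRAAdapterReqInput(
    lora_name="my-adapter",          # hypothetical adapter name
    lora_path="/models/my-adapter",  # hypothetical path
    pinned=True,
)
print(req)
```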
sglang/srt/managers/mm_utils.py
CHANGED
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
         embedding_per_req = data_embedding_func(embedding_items_per_req)
         if not embedding_cache.put(embedding_items_hash, embedding_per_req):
             print_warning_once(
-                "Multimodal embedding cache is full.
-                "
+                "Multimodal embedding cache is full. This typically occurs when a single "
+                "embedding exceeds the cache size limit. Consider increasing the "
+                "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+                "embedding size."
             )
 
-        embedding_per_req_chunk, _,
+        embedding_per_req_chunk, _, _ = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
             extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
-        # remove this item from cache if chunk reaches to the end
-        embedding_per_req_length = (
-            embedding_per_req.shape[0]
-            if embedding_per_req.dim() == 2
-            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
-        )
-        if end_index == embedding_per_req_length:
-            embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
         return None
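The rewritten warning points at the `SGLANG_VLM_CACHE_SIZE_MB` environment variable; one way to act on it is to raise the value before the server process starts (the number below is arbitrary).

```python
import os

# Give the multimodal embedding cache more room before launching sglang;
# 4096 MB is an arbitrary example value, not a recommended default.
os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "4096"
```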
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -85,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_radix_cache",
     "enable_dp_attention",
     "enable_two_batch_overlap",
+    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
     "moe_a2a_backend",
     "deepep_mode",
@@ -107,8 +107,10 @@
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
+    "quantization",
 ]
 
 # Put some global args for easy access
@@ -423,7 +425,7 @@ class Req:
         token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
-
+        lora_id: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
@@ -467,7 +469,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
-        self.
+        self.lora_id = lora_id
 
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -845,6 +847,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32
 
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
@@ -917,8 +921,10 @@
 
         is_hybrid = False
         if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
-            assert
-            tree_cache
+            assert (
+                tree_cache is None
+                or isinstance(tree_cache, SWARadixCache)
+                or isinstance(tree_cache, SWAChunkCache)
             ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
             is_hybrid = True
 
@@ -1128,6 +1134,7 @@
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
@@ -1144,6 +1151,9 @@
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
@@ -1257,6 +1267,7 @@
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1504,6 +1515,7 @@
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
@@ -1558,9 +1570,11 @@
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
         # free memory
@@ -1624,6 +1638,7 @@
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
@@ -1656,6 +1671,7 @@
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
@@ -1705,6 +1721,7 @@
             or attention_backend_str == "flashmla"
             or attention_backend_str == "cutlass_mla"
             or attention_backend_str == "ascend"
+            or attention_backend_str == "trtllm_mha"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
@@ -1729,6 +1746,7 @@
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_cpu=seq_lens_cpu,
             seq_lens_sum=self.seq_lens_sum,
@@ -1750,7 +1768,7 @@
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
             encoder_out_cache_loc=self.encoder_out_cache_loc,
-
+            lora_ids=[req.lora_id for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
             token_type_ids=self.token_type_ids,
@@ -1891,11 +1909,14 @@
     encoder_out_cache_loc: Optional[torch.Tensor]
 
     # For LoRA
-
+    lora_ids: Optional[List[str]]
 
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # The input Embeds
     input_embeds: Optional[torch.Tensor] = None
 
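A toy example of the `orig_seq_lens` bookkeeping added throughout `ScheduleBatch` above: per request it is the larger of the filled length and the original prompt length (values below are invented).

```python
# len(r.fill_ids) and len(r.origin_input_ids) for two hypothetical requests
fill_ids_lens = [128, 4096]
origin_input_lens = [128, 1_000_000]  # e.g. a very long (Qwen-1M style) prompt

orig_seq_lens = [max(f, o) for f, o in zip(fill_ids_lens, origin_input_lens)]
print(orig_seq_lens)  # [128, 1000000]
```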
sglang/srt/managers/scheduler.py
CHANGED
@@ -120,6 +120,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
 from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
 from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
@@ -472,8 +473,10 @@ class Scheduler(
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
+        self.offload_tags = set()
         self.init_profier()
 
+        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
             SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
             if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
@@ -616,6 +619,7 @@ class Scheduler(
                 ),
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
+                hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -946,6 +950,14 @@ class Scheduler(
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
+
+        if self.recv_skipper is not None:
+            last_forward_mode = (
+                self.last_batch.forward_mode if self.last_batch is not None else None
+            )
+            if not self.recv_skipper.handle(last_forward_mode):
+                return []
+
         if self.pp_rank == 0:
             if self.attn_tp_rank == 0:
                 recv_reqs = []
@@ -1029,7 +1041,9 @@ class Scheduler(
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None
+                self.chunked_req is not None
+                or not self.running_batch.is_empty()
+                or len(self.offload_tags) > 0
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -1090,7 +1104,7 @@ class Scheduler(
             top_logprobs_num=recv_req.top_logprobs_num,
             token_ids_logprob=recv_req.token_ids_logprob,
             stream=recv_req.stream,
-
+            lora_id=recv_req.lora_id,
             input_embeds=recv_req.input_embeds,
             custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
@@ -1534,18 +1548,15 @@ class Scheduler(
         self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
         if self.enable_lora:
-            lora_set = set([req.
+            lora_set = set([req.lora_id for req in self.running_batch.reqs])
 
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
-
-
-
-
-
-                | set([req.lora_path])
-                )
-                > self.max_loras_per_batch
+
+            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
+                lora_set
+                | set([req.lora_id for req in adder.can_run_list])
+                | set([req.lora_id])
             ):
                 self.running_batch.batch_is_full = True
                 break
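Roughly, the new `can_run_lora_batch` call asks whether the adapters already running, those staged for this prefill, and the candidate request's adapter can all share one batch; a sketch with hypothetical names and an assumed per-batch limit:

```python
max_loras_per_batch = 2   # assumed limit
running = {"adapter-a"}   # adapters in the running batch
staged = {"adapter-b"}    # adapters already accepted into this prefill batch
candidate = "adapter-c"   # adapter of the request being considered

needed = running | staged | {candidate}
print(len(needed) <= max_loras_per_batch)  # False: a third adapter does not fit
```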
@@ -1562,7 +1573,10 @@ class Scheduler(
                 break
 
             if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging requests that are ongoing prefetch
+                    continue
 
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:
 
                 req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
-
-                output_ids.append(req.output_ids[send_token_offset:])
+                output_ids.append(req.output_ids[send_token_offset:])
                 req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
sglang/srt/managers/scheduler_recv_skipper.py
ADDED
@@ -0,0 +1,37 @@
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.server_args import ServerArgs
+
+
+class SchedulerRecvSkipper:
+    @staticmethod
+    def maybe_create(server_args: ServerArgs):
+        if server_args.scheduler_recv_interval <= 1:
+            return None
+        return SchedulerRecvSkipper(server_args)
+
+    def __init__(self, server_args: ServerArgs):
+        # Can be supported if needed, but may need e.g. `global_forward_mode`
+        assert not server_args.enable_dp_attention
+        self._counter = 0
+        self._threshold = server_args.scheduler_recv_interval
+
+    def handle(self, last_forward_mode: ForwardMode):
+        should_recv = False
+
+        last_weight = _WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, _DEFAULT_WEIGHT)
+        self._counter += last_weight
+
+        if self._counter >= self._threshold:
+            self._counter = 0
+            should_recv = True
+
+        return should_recv
+
+
+# All can be tuned if needed
+_DEFAULT_WEIGHT = 1000
+_WEIGHT_OF_FORWARD_MODE = {
+    ForwardMode.DECODE: 1,
+    ForwardMode.TARGET_VERIFY: 1,
+    None: 1,
+}
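A self-contained sketch of how the skipper's weighted counter behaves; the enum below is a stub standing in for `ForwardMode`, and the threshold plays the role of `scheduler_recv_interval`.

```python
from enum import Enum, auto

class ForwardMode(Enum):  # stub for sglang's ForwardMode
    DECODE = auto()
    TARGET_VERIFY = auto()
    EXTEND = auto()

_DEFAULT_WEIGHT = 1000
_WEIGHT_OF_FORWARD_MODE = {ForwardMode.DECODE: 1, ForwardMode.TARGET_VERIFY: 1, None: 1}

counter, threshold = 0, 4  # threshold ~ scheduler_recv_interval
for step, mode in enumerate([ForwardMode.DECODE] * 6 + [ForwardMode.EXTEND]):
    counter += _WEIGHT_OF_FORWARD_MODE.get(mode, _DEFAULT_WEIGHT)
    if counter >= threshold:
        counter = 0
        print(f"step {step}: poll for new requests")
    # otherwise this iteration skips recv_requests() entirely
```

Light decode or verify steps only trigger a poll every few iterations, while any other step carries the default weight of 1000 and forces a poll immediately.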
sglang/srt/managers/scheduler_update_weights_mixin.py
CHANGED
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+        for tag in tags:
+            self.offload_tags.add(tag)
+
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
             self.flush_cache()
@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+        for tag in tags:
+            self.offload_tags.remove(tag)
+
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
sglang/srt/managers/template_manager.py
CHANGED
@@ -21,6 +21,7 @@ and code completion templates, eliminating global state and improving modularity
 import json
 import logging
 import os
+import re
 from typing import Optional
 
 from sglang.srt.code_completion_parser import (
@@ -54,6 +55,7 @@ class TemplateManager:
         self._chat_template_name: Optional[str] = None
         self._completion_template_name: Optional[str] = None
         self._jinja_template_content_format: Optional[str] = "openai"
+        self._force_reasoning: bool = False
 
     @property
     def chat_template_name(self) -> Optional[str]:
@@ -70,6 +72,31 @@
         """Get the detected template content format ('string' or 'openai' or None)."""
         return self._jinja_template_content_format
 
+    @property
+    def force_reasoning(self) -> bool:
+        """
+        Check if the current chat template enforces reasoning/thinking.
+
+        Returns:
+            True if the template contains reasoning patterns like <think> tags
+        """
+        return self._force_reasoning
+
+    def _detect_reasoning_pattern(self, template: str) -> bool:
+        """
+        Detect if the chat template contains reasoning/thinking patterns.
+        """
+        if template is None:
+            return False
+
+        force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
+        has_reasoning = re.search(force_reasoning_pattern, template) is not None
+
+        if has_reasoning:
+            logger.info("Detected the force reasoning pattern in chat template.")
+
+        return has_reasoning
+
     def load_chat_template(
         self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
     ) -> None:
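A quick check of the pattern used by `_detect_reasoning_pattern`: note that `\\n` in the raw regex matches a literal backslash-n sequence in the template text, not a newline (the template fragments below are made-up examples).

```python
import re

force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"

# A made-up template fragment that spells newlines as literal "\n" characters.
template = "{{ messages }}<|im_start|>assistant\\n<think>\\n"
print(re.search(force_reasoning_pattern, template) is not None)  # True

# The same fragment with real newline characters does not match.
template_real_newlines = "{{ messages }}<|im_start|>assistant\n<think>\n"
print(re.search(force_reasoning_pattern, template_real_newlines) is not None)  # False
```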
@@ -93,7 +120,8 @@
             hf_template = self._resolve_hf_chat_template(tokenizer_manager)
             if hf_template:
                 # override the chat template
-                tokenizer_manager.tokenizer
+                if tokenizer_manager.tokenizer:
+                    tokenizer_manager.tokenizer.chat_template = hf_template
                 self._jinja_template_content_format = (
                     detect_jinja_template_content_format(hf_template)
                 )
@@ -106,6 +134,12 @@
             self._jinja_template_content_format = "string"
             logger.info("No chat template found, defaulting to 'string' content format")
 
+        # Detect reasoning pattern from chat template
+        if tokenizer_manager.tokenizer:
+            self._force_reasoning = self._detect_reasoning_pattern(
+                tokenizer_manager.tokenizer.chat_template
+            )
+
     def _load_explicit_chat_template(
         self, tokenizer_manager, chat_template_arg: str
     ) -> None: