sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ limitations under the License.
  import logging
  import math
  import threading
+ import time
  from queue import Empty, Full, PriorityQueue, Queue
  from typing import TYPE_CHECKING, List, Optional
 
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
  self._done_flag = False
  self._lock = threading.Lock()
 
+ self.start_time = time.monotonic()
+
  super().__init__(host_indices, token_ids, last_hash)
 
  def increment(self, num_tokens: int):
@@ -243,12 +246,12 @@ class HiCacheController:
  self.storage_backend = HiCacheFile()
  self.get_hash_str = get_hash_str
  elif storage_backend == "nixl":
- from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+ from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
  self.storage_backend = HiCacheNixl()
  self.get_hash_str = get_hash_str
  elif storage_backend == "mooncake":
- from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+ from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
  MooncakeStore,
  get_hash_str_mooncake,
  )
@@ -278,6 +281,12 @@ class HiCacheController:
  self.enable_storage = True
  # todo: threshold policy for prefetching
  self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+ self.prefetch_capacity_limit = int(
+ 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+ )
+ # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+ self.prefetch_tokens_occupied = 0
+
  # create a new communication group for synchronizing storage operations across TP workers
  self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
  if self.tp_world_size > 1:
@@ -525,7 +534,7 @@ class HiCacheController:
  host_indices: torch.Tensor,
  new_input_tokens: List[int],
  last_hash: Optional[str] = None,
- ) -> int:
+ ) -> PrefetchOperation:
  """
  Prefetch KV caches from storage backend to host memory.
  """
@@ -586,11 +595,23 @@ class HiCacheController:
  operation = self.prefetch_buffer.get(block=True, timeout=1)
  if self.is_mooncake_backend():
  self.mooncake_page_transfer(operation)
+ elif self.storage_backend_type == "hf3fs":
+ self.generic_page_transfer(operation, batch_size=128)
  else:
  self.generic_page_transfer(operation)
  except Empty:
  continue
 
+ def prefetch_rate_limit_check(self) -> bool:
+ """
+ Rate limit the prefetching operations to avoid overwhelming the storage backend.
+ """
+ # cancel prefetch if too much memory is occupied
+ if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+ return False
+ # todo: more sophisticated rate limiting based on storage backend performance
+ return True
+
  def prefetch_thread_func(self):
  """
  Manage prefetching operations from storage backend to host memory.
@@ -604,34 +625,36 @@ class HiCacheController:
  if operation is None:
  continue
 
- last_hash = operation.last_hash
- tokens_to_fetch = operation.token_ids
-
  storage_hit_count = 0
- remaining_tokens = len(tokens_to_fetch)
- hash_value = []
- while remaining_tokens >= self.page_size:
- last_hash = self.get_hash_str(
- tokens_to_fetch[
- storage_hit_count : storage_hit_count + self.page_size
- ],
- last_hash,
- )
-
- # todo, more unified interface
- if not self.is_mooncake_backend():
- if not self.storage_backend.exists(last_hash):
- break
- hash_value.append(last_hash)
- storage_hit_count += self.page_size
- remaining_tokens -= self.page_size
-
- if self.is_mooncake_backend():
- # deferring to batch exists for mooncake store
- exist_result = self.storage_backend.exists(hash_value)
- storage_hit_count = (
- sum(1 for v in exist_result.values() if v != 0) * self.page_size
- )
+ if self.prefetch_rate_limit_check():
+ last_hash = operation.last_hash
+ tokens_to_fetch = operation.token_ids
+
+ remaining_tokens = len(tokens_to_fetch)
+ hash_value = []
+ while remaining_tokens >= self.page_size:
+ last_hash = self.get_hash_str(
+ tokens_to_fetch[
+ storage_hit_count : storage_hit_count + self.page_size
+ ],
+ last_hash,
+ )
+
+ # todo, more unified interface
+ if not self.is_mooncake_backend():
+ if not self.storage_backend.exists(last_hash):
+ break
+ hash_value.append(last_hash)
+ storage_hit_count += self.page_size
+ remaining_tokens -= self.page_size
+
+ if self.is_mooncake_backend():
+ # deferring to batch exists for mooncake store
+ exist_result = self.storage_backend.exists(hash_value)
+ storage_hit_count = (
+ sum(1 for v in exist_result.values() if v != 0)
+ * self.page_size
+ )
 
  if self.tp_world_size > 1:
  storage_hit_count_tensor = torch.tensor(
@@ -750,6 +773,8 @@ class HiCacheController:
 
  if self.is_mooncake_backend():
  self.mooncake_page_backup(operation)
+ elif self.storage_backend_type == "hf3fs":
+ self.generic_page_backup(operation, batch_size=128)
  else:
  self.generic_page_backup(operation)
 
@@ -216,7 +216,7 @@ class DetokenizerManager:
  rids=recv_obj.rids,
  finished_reasons=recv_obj.finished_reasons,
  output_strs=output_strs,
- output_ids=None,
+ output_ids=recv_obj.decode_ids,
  prompt_tokens=recv_obj.prompt_tokens,
  completion_tokens=recv_obj.completion_tokens,
  cached_tokens=recv_obj.cached_tokens,
@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
  from sglang.srt.managers.schedule_batch import BaseFinishReason
  from sglang.srt.multimodal.mm_utils import has_valid_data
  from sglang.srt.sampling.sampling_params import SamplingParams
+ from sglang.srt.utils import ImageData
 
  # Handle serialization of Image for pydantic
  if TYPE_CHECKING:
@@ -45,7 +46,7 @@ class SessionParams:
 
  # Type definitions for multimodal input data
  # Individual data item types for each modality
- ImageDataInputItem = Union[Image, str, Dict]
+ ImageDataInputItem = Union[Image, str, ImageData, Dict]
  AudioDataInputItem = Union[str, Dict]
  VideoDataInputItem = Union[str, Dict]
  # Union type for any multimodal data item
@@ -101,8 +102,10 @@ class GenerateReqInput:
 
  # The modalities of the image data [image, multi-images, video]
  modalities: Optional[List[str]] = None
- # The path to the LoRA
+ # The path to the LoRA adaptors
  lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+ # The uid of LoRA adaptors, should be initialized by tokenizer manager
+ lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
  # Session info for continual prompting
  session_params: Optional[Union[List[Dict], Dict]] = None
@@ -123,6 +126,9 @@ class GenerateReqInput:
  # For data parallel rank routing
  data_parallel_rank: Optional[int] = None
 
+ # For background responses (OpenAI responses API)
+ background: bool = False
+
  def contains_mm_input(self) -> bool:
  return (
  has_valid_data(self.image_data)
@@ -500,7 +506,7 @@ class TokenizedGenerateReqInput:
  stream: bool
 
  # LoRA related
- lora_path: Optional[str] = None # None means just use the base model
+ lora_id: Optional[str] = None # None means just use the base model
  # The input embeds
  input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -557,6 +563,9 @@ class EmbeddingReqInput:
  # For cross-encoder requests
  is_cross_encoder_request: bool = False
 
+ # For background responses (OpenAI responses API)
+ background: bool = False
+
  def normalize_batch_and_arguments(self):
  # at least one of text, input_ids, or image should be provided
  if self.text is None and self.input_ids is None and self.image_data is None:
@@ -1073,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
  lora_name: str
  # The path of loading.
  lora_path: str
+ # Whether to pin the LoRA adapter in memory.
+ pinned: bool = False
  # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
  lora_id: Optional[str] = None
 
@@ -1081,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
  lora_id=self.lora_id,
  lora_name=self.lora_name,
  lora_path=self.lora_path,
+ pinned=self.pinned,
  )
 
 
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
  embedding_per_req = data_embedding_func(embedding_items_per_req)
  if not embedding_cache.put(embedding_items_hash, embedding_per_req):
  print_warning_once(
- "Multimodal embedding cache is full. Consider increasing the "
- "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+ "Multimodal embedding cache is full. This typically occurs when a single "
+ "embedding exceeds the cache size limit. Consider increasing the "
+ "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+ "embedding size."
  )
 
- embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+ embedding_per_req_chunk, _, _ = get_embedding_chunk(
  embedding=embedding_per_req,
  extend_prefix_len=prefix_length[i],
  extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
  items_offset=items_offset,
  )
- # remove this item from cache if chunk reaches to the end
- embedding_per_req_length = (
- embedding_per_req.shape[0]
- if embedding_per_req.dim() == 2
- else embedding_per_req.shape[0] * embedding_per_req.shape[1]
- )
- if end_index == embedding_per_req_length:
- embedding_cache.free(embedding_items_hash)
  embedding_list.append(embedding_per_req_chunk)
  if len(embedding_list) == 0:
  return None
@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
  ScheduleBatchDisaggregationDecodeMixin,
  )
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
- from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.mem_cache.allocator import (
  BaseTokenToKVPoolAllocator,
  SWATokenToKVPoolAllocator,
@@ -85,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "disable_radix_cache",
  "enable_dp_attention",
  "enable_two_batch_overlap",
+ "tbo_token_distribution_threshold",
  "enable_dp_lm_head",
  "moe_a2a_backend",
  "deepep_mode",
@@ -107,8 +107,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "num_reserved_decode_tokens",
  "weight_loader_disable_mmap",
  "enable_triton_kernel_moe",
+ "enable_flashinfer_mxfp4_moe",
  "enable_multimodal",
  "enable_symm_mem",
+ "quantization",
  ]
 
  # Put some global args for easy access
@@ -423,7 +425,7 @@ class Req:
  token_ids_logprob: List[int] = None,
  stream: bool = False,
  origin_input_ids_unpadded: Optional[Tuple[int]] = None,
- lora_path: Optional[str] = None,
+ lora_id: Optional[str] = None,
  input_embeds: Optional[List[List[float]]] = None,
  token_type_ids: List[int] = None,
  session_id: Optional[str] = None,
@@ -467,7 +469,7 @@ class Req:
  self.sampling_params = sampling_params
  self.custom_logit_processor = custom_logit_processor
  self.return_hidden_states = return_hidden_states
- self.lora_path = lora_path
+ self.lora_id = lora_id
 
  # Memory pool info
  self.req_pool_idx: Optional[int] = None
@@ -845,6 +847,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
  # The sum of all sequence lengths
  seq_lens_sum: int = None
+ # The original sequence lengths, Qwen-1M related
+ orig_seq_lens: torch.Tensor = None # shape: [b], int32
 
  # For DP attention
  global_num_tokens: Optional[List[int]] = None
@@ -917,8 +921,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
  is_hybrid = False
  if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
- assert isinstance(tree_cache, SWARadixCache) or isinstance(
- tree_cache, SWAChunkCache
+ assert (
+ tree_cache is None
+ or isinstance(tree_cache, SWARadixCache)
+ or isinstance(tree_cache, SWAChunkCache)
  ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
  is_hybrid = True
 
@@ -1128,6 +1134,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
  extend_num_tokens = sum(len(ids) for ids in input_ids)
  seq_lens = [len(r.fill_ids) for r in reqs]
+ orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
  prefix_lens = [len(r.prefix_indices) for r in reqs]
  extend_lens = [r.extend_input_len for r in reqs]
 
@@ -1144,6 +1151,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
  self.device, non_blocking=True
  )
+ orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+ self.device, non_blocking=True
+ )
  prefix_lens_tensor = torch.tensor(
  prefix_lens, dtype=torch.int64, device=self.device
  )
@@ -1257,6 +1267,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.input_ids = input_ids_tensor
  self.req_pool_indices = req_pool_indices_tensor
  self.seq_lens = seq_lens_tensor
+ self.orig_seq_lens = orig_seq_lens_tensor
  self.out_cache_loc = out_cache_loc
  self.input_embeds = (
  torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1504,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.forward_mode = ForwardMode.IDLE
  self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
  self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+ self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
  self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
  self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
  self.seq_lens_sum = 0
@@ -1558,9 +1570,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  if self.enable_overlap:
  # Do not use in-place operations in the overlap mode
  self.seq_lens = self.seq_lens + 1
+ self.orig_seq_lens = self.orig_seq_lens + 1
  else:
  # A faster in-place version
  self.seq_lens.add_(1)
+ self.orig_seq_lens.add_(1)
  self.seq_lens_sum += bs
 
  # free memory
@@ -1624,6 +1638,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
  self.req_pool_indices = self.req_pool_indices[keep_indices_device]
  self.seq_lens = self.seq_lens[keep_indices_device]
+ self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
  self.out_cache_loc = None
  self.seq_lens_sum = self.seq_lens.sum().item()
  self.output_ids = self.output_ids[keep_indices_device]
@@ -1656,6 +1671,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  [self.req_pool_indices, other.req_pool_indices]
  )
  self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+ self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
  self.out_cache_loc = None
  self.seq_lens_sum += other.seq_lens_sum
  if self.output_ids is not None:
@@ -1705,6 +1721,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  or attention_backend_str == "flashmla"
  or attention_backend_str == "cutlass_mla"
  or attention_backend_str == "ascend"
+ or attention_backend_str == "trtllm_mha"
  or global_server_args_dict["enable_two_batch_overlap"]
  ):
  seq_lens_cpu = (
@@ -1729,6 +1746,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  input_ids=self.input_ids,
  req_pool_indices=self.req_pool_indices,
  seq_lens=self.seq_lens,
+ orig_seq_lens=self.orig_seq_lens,
  out_cache_loc=self.out_cache_loc,
  seq_lens_cpu=seq_lens_cpu,
  seq_lens_sum=self.seq_lens_sum,
@@ -1750,7 +1768,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  encoder_lens=self.encoder_lens,
  encoder_lens_cpu=self.encoder_lens_cpu,
  encoder_out_cache_loc=self.encoder_out_cache_loc,
- lora_paths=[req.lora_path for req in self.reqs],
+ lora_ids=[req.lora_id for req in self.reqs],
  sampling_info=self.sampling_info,
  input_embeds=self.input_embeds,
  token_type_ids=self.token_type_ids,
@@ -1891,11 +1909,14 @@ class ModelWorkerBatch:
  encoder_out_cache_loc: Optional[torch.Tensor]
 
  # For LoRA
- lora_paths: Optional[List[str]]
+ lora_ids: Optional[List[str]]
 
  # Sampling info
  sampling_info: SamplingBatchInfo
 
+ # The original sequence lengths, Qwen-1M related
+ orig_seq_lens: Optional[torch.Tensor] = None
+
  # The input Embeds
  input_embeds: Optional[torch.Tensor] = None
 
@@ -120,6 +120,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
  SchedulerOutputProcessorMixin,
  )
  from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+ from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
  from sglang.srt.managers.scheduler_update_weights_mixin import (
  SchedulerUpdateWeightsMixin,
  )
@@ -472,8 +473,10 @@ class Scheduler(
  self.memory_saver_adapter = TorchMemorySaverAdapter.create(
  enable=server_args.enable_memory_saver
  )
+ self.offload_tags = set()
  self.init_profier()
 
+ self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
  self.input_blocker = (
  SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
  if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
@@ -616,6 +619,7 @@ class Scheduler(
  ),
  hicache_mem_layout=server_args.hicache_mem_layout,
  hicache_storage_backend=server_args.hicache_storage_backend,
+ hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
  )
  self.tp_worker.register_hicache_layer_transfer_counter(
  self.tree_cache.cache_controller.layer_done_counter
@@ -946,6 +950,14 @@ class Scheduler(
 
  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
+
+ if self.recv_skipper is not None:
+ last_forward_mode = (
+ self.last_batch.forward_mode if self.last_batch is not None else None
+ )
+ if not self.recv_skipper.handle(last_forward_mode):
+ return []
+
  if self.pp_rank == 0:
  if self.attn_tp_rank == 0:
  recv_reqs = []
@@ -1029,7 +1041,9 @@ class Scheduler(
  for recv_req in recv_reqs:
  # If it is a health check generation request and there are running requests, ignore it.
  if is_health_check_generate_req(recv_req) and (
- self.chunked_req is not None or not self.running_batch.is_empty()
+ self.chunked_req is not None
+ or not self.running_batch.is_empty()
+ or len(self.offload_tags) > 0
  ):
  self.return_health_check_ct += 1
  continue
@@ -1090,7 +1104,7 @@ class Scheduler(
  top_logprobs_num=recv_req.top_logprobs_num,
  token_ids_logprob=recv_req.token_ids_logprob,
  stream=recv_req.stream,
- lora_path=recv_req.lora_path,
+ lora_id=recv_req.lora_id,
  input_embeds=recv_req.input_embeds,
  custom_logit_processor=recv_req.custom_logit_processor,
  return_hidden_states=recv_req.return_hidden_states,
@@ -1534,18 +1548,15 @@ class Scheduler(
  self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
  if self.enable_lora:
- lora_set = set([req.lora_path for req in self.running_batch.reqs])
+ lora_set = set([req.lora_id for req in self.running_batch.reqs])
 
  # Get requests from the waiting queue to a new prefill batch
  for req in self.waiting_queue:
- if (
- self.enable_lora
- and len(
- lora_set
- | set([req.lora_path for req in adder.can_run_list])
- | set([req.lora_path])
- )
- > self.max_loras_per_batch
+
+ if self.enable_lora and not self.tp_worker.can_run_lora_batch(
+ lora_set
+ | set([req.lora_id for req in adder.can_run_list])
+ | set([req.lora_id])
  ):
  self.running_batch.batch_is_full = True
  break
@@ -1562,7 +1573,10 @@ class Scheduler(
  break
 
  if self.enable_hicache_storage:
- self.tree_cache.check_prefetch_progress(req.rid)
+ prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+ if not prefetch_done:
+ # skip staging requests that are ongoing prefetch
+ continue
 
  req.init_next_round_input(self.tree_cache)
  res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:
 
  req.send_decode_id_offset = len(decode_ids)
  read_offsets.append(read_offset)
- if self.skip_tokenizer_init:
- output_ids.append(req.output_ids[send_token_offset:])
+ output_ids.append(req.output_ids[send_token_offset:])
  req.send_token_offset = len(req.output_ids)
  skip_special_tokens.append(req.sampling_params.skip_special_tokens)
  spaces_between_special_tokens.append(
@@ -0,0 +1,37 @@
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
+ from sglang.srt.server_args import ServerArgs
+
+
+ class SchedulerRecvSkipper:
+ @staticmethod
+ def maybe_create(server_args: ServerArgs):
+ if server_args.scheduler_recv_interval <= 1:
+ return None
+ return SchedulerRecvSkipper(server_args)
+
+ def __init__(self, server_args: ServerArgs):
+ # Can be supported if needed, but may need e.g. `global_forward_mode`
+ assert not server_args.enable_dp_attention
+ self._counter = 0
+ self._threshold = server_args.scheduler_recv_interval
+
+ def handle(self, last_forward_mode: ForwardMode):
+ should_recv = False
+
+ last_weight = _WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, _DEFAULT_WEIGHT)
+ self._counter += last_weight
+
+ if self._counter >= self._threshold:
+ self._counter = 0
+ should_recv = True
+
+ return should_recv
+
+
+ # All can be tuned if needed
+ _DEFAULT_WEIGHT = 1000
+ _WEIGHT_OF_FORWARD_MODE = {
+ ForwardMode.DECODE: 1,
+ ForwardMode.TARGET_VERIFY: 1,
+ None: 1,
+ }
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
  if tags is None or len(tags) == 0:
  tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+ for tag in tags:
+ self.offload_tags.add(tag)
+
  if GPU_MEMORY_TYPE_KV_CACHE in tags:
  self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
  self.flush_cache()
@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
  if tags is None or len(tags) == 0:
  tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+ for tag in tags:
+ self.offload_tags.remove(tag)
+
  if GPU_MEMORY_TYPE_WEIGHTS in tags:
  self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
  torch.distributed.barrier(self.tp_cpu_group)
@@ -21,6 +21,7 @@ and code completion templates, eliminating global state and improving modularity
  import json
  import logging
  import os
+ import re
  from typing import Optional
 
  from sglang.srt.code_completion_parser import (
@@ -54,6 +55,7 @@ class TemplateManager:
  self._chat_template_name: Optional[str] = None
  self._completion_template_name: Optional[str] = None
  self._jinja_template_content_format: Optional[str] = "openai"
+ self._force_reasoning: bool = False
 
  @property
  def chat_template_name(self) -> Optional[str]:
@@ -70,6 +72,31 @@ class TemplateManager:
  """Get the detected template content format ('string' or 'openai' or None)."""
  return self._jinja_template_content_format
 
+ @property
+ def force_reasoning(self) -> bool:
+ """
+ Check if the current chat template enforces reasoning/thinking.
+
+ Returns:
+ True if the template contains reasoning patterns like <think> tags
+ """
+ return self._force_reasoning
+
+ def _detect_reasoning_pattern(self, template: str) -> bool:
+ """
+ Detect if the chat template contains reasoning/thinking patterns.
+ """
+ if template is None:
+ return False
+
+ force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
+ has_reasoning = re.search(force_reasoning_pattern, template) is not None
+
+ if has_reasoning:
+ logger.info("Detected the force reasoning pattern in chat template.")
+
+ return has_reasoning
+
  def load_chat_template(
  self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
  ) -> None:
@@ -93,7 +120,8 @@ class TemplateManager:
  hf_template = self._resolve_hf_chat_template(tokenizer_manager)
  if hf_template:
  # override the chat template
- tokenizer_manager.tokenizer.chat_template = hf_template
+ if tokenizer_manager.tokenizer:
+ tokenizer_manager.tokenizer.chat_template = hf_template
  self._jinja_template_content_format = (
  detect_jinja_template_content_format(hf_template)
  )
@@ -106,6 +134,12 @@ class TemplateManager:
  self._jinja_template_content_format = "string"
  logger.info("No chat template found, defaulting to 'string' content format")
 
+ # Detect reasoning pattern from chat template
+ if tokenizer_manager.tokenizer:
+ self._force_reasoning = self._detect_reasoning_pattern(
+ tokenizer_manager.tokenizer.chat_template
+ )
+
  def _load_explicit_chat_template(
  self, tokenizer_manager, chat_template_arg: str
  ) -> None: