sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
@@ -16,6 +16,7 @@ limitations under the License.
  import logging
  import math
  import threading
+ import time
  from queue import Empty, Full, PriorityQueue, Queue
  from typing import TYPE_CHECKING, List, Optional

@@ -168,12 +169,13 @@ class StorageOperation:
  host_indices: torch.Tensor,
  token_ids: List[int],
  last_hash: Optional[str] = None,
+ hash_value: Optional[List[str]] = None,
  ):
  self.host_indices = host_indices
  self.token_ids = token_ids
  self.last_hash = last_hash
  self.completed_tokens = 0
- self.hash_value = []
+ self.hash_value = hash_value if hash_value is not None else []

  self.id = StorageOperation.counter
  StorageOperation.counter += 1
@@ -195,6 +197,8 @@ class PrefetchOperation(StorageOperation):
  self._done_flag = False
  self._lock = threading.Lock()

+ self.start_time = time.monotonic()
+
  super().__init__(host_indices, token_ids, last_hash)

  def increment(self, num_tokens: int):
@@ -243,12 +247,12 @@ class HiCacheController:
  self.storage_backend = HiCacheFile()
  self.get_hash_str = get_hash_str
  elif storage_backend == "nixl":
- from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+ from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

  self.storage_backend = HiCacheNixl()
  self.get_hash_str = get_hash_str
  elif storage_backend == "mooncake":
- from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+ from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
  MooncakeStore,
  get_hash_str_mooncake,
  )
@@ -256,6 +260,7 @@ class HiCacheController:
  self.storage_backend = MooncakeStore()
  self.get_hash_str = get_hash_str_mooncake
  self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+ assert self.mem_pool_host.layout == "page_first"
  elif storage_backend == "hf3fs":
  from sglang.srt.distributed import get_tensor_model_parallel_rank
  from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
@@ -278,6 +283,12 @@ class HiCacheController:
  self.enable_storage = True
  # todo: threshold policy for prefetching
  self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+ self.prefetch_capacity_limit = int(
+ 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+ )
+ # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+ self.prefetch_tokens_occupied = 0
+
  # create a new communication group for synchronizing storage operations across TP workers
  self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
  if self.tp_world_size > 1:
@@ -424,7 +435,9 @@ class HiCacheController:
  if self.io_backend == "kernel":
  return host_indices.to(self.mem_pool_device.device), device_indices
  elif self.io_backend == "direct":
- return host_indices, device_indices.cpu()
+ device_indices = device_indices.cpu()
+ host_indices, idx = host_indices.sort()
+ return host_indices, device_indices.index_select(0, idx)
  else:
  raise ValueError(f"Unsupported io backend")

@@ -525,7 +538,7 @@ class HiCacheController:
  host_indices: torch.Tensor,
  new_input_tokens: List[int],
  last_hash: Optional[str] = None,
- ) -> int:
+ ) -> PrefetchOperation:
  """
  Prefetch KV caches from storage backend to host memory.
  """
@@ -561,10 +574,6 @@ class HiCacheController:
  )
  completed_tokens += self.page_size
  else:
- # operation terminated by controller, release pre-allocated memory
- self.mem_pool_host.free(
- operation.host_indices[operation.completed_tokens :]
- )
  break

  def mooncake_page_transfer(self, operation):
@@ -586,11 +595,31 @@ class HiCacheController:
  operation = self.prefetch_buffer.get(block=True, timeout=1)
  if self.is_mooncake_backend():
  self.mooncake_page_transfer(operation)
+ elif self.storage_backend_type == "hf3fs":
+ self.generic_page_transfer(operation, batch_size=128)
  else:
  self.generic_page_transfer(operation)
+
+ if self.tp_world_size > 1:
+ # to ensure all TP workers release the host memory at the same time
+ torch.distributed.barrier(group=self.prefetch_tp_group)
+ # operation terminated by controller, release pre-allocated memory
+ self.mem_pool_host.free(
+ operation.host_indices[operation.completed_tokens :]
+ )
  except Empty:
  continue

+ def prefetch_rate_limit_check(self) -> bool:
+ """
+ Rate limit the prefetching operations to avoid overwhelming the storage backend.
+ """
+ # cancel prefetch if too much memory is occupied
+ if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+ return False
+ # todo: more sophisticated rate limiting based on storage backend performance
+ return True
+
  def prefetch_thread_func(self):
  """
  Manage prefetching operations from storage backend to host memory.
@@ -604,34 +633,38 @@ class HiCacheController:
  if operation is None:
  continue

- last_hash = operation.last_hash
- tokens_to_fetch = operation.token_ids
-
  storage_hit_count = 0
- remaining_tokens = len(tokens_to_fetch)
- hash_value = []
- while remaining_tokens >= self.page_size:
- last_hash = self.get_hash_str(
- tokens_to_fetch[
- storage_hit_count : storage_hit_count + self.page_size
- ],
- last_hash,
- )
-
- # todo, more unified interface
- if not self.is_mooncake_backend():
- if not self.storage_backend.exists(last_hash):
- break
- hash_value.append(last_hash)
- storage_hit_count += self.page_size
- remaining_tokens -= self.page_size
-
- if self.is_mooncake_backend():
- # deferring to batch exists for mooncake store
- exist_result = self.storage_backend.exists(hash_value)
- storage_hit_count = (
- sum(1 for v in exist_result.values() if v != 0) * self.page_size
- )
+ if (
+ operation.host_indices is not None
+ ) and self.prefetch_rate_limit_check():
+ last_hash = operation.last_hash
+ tokens_to_fetch = operation.token_ids
+
+ remaining_tokens = len(tokens_to_fetch)
+ hash_value = []
+ while remaining_tokens >= self.page_size:
+ last_hash = self.get_hash_str(
+ tokens_to_fetch[
+ storage_hit_count : storage_hit_count + self.page_size
+ ],
+ last_hash,
+ )
+
+ # todo, more unified interface
+ if not self.is_mooncake_backend():
+ if not self.storage_backend.exists(last_hash):
+ break
+ hash_value.append(last_hash)
+ storage_hit_count += self.page_size
+ remaining_tokens -= self.page_size
+
+ if self.is_mooncake_backend():
+ # deferring to batch exists for mooncake store
+ exist_result = self.storage_backend.exists(hash_value)
+ storage_hit_count = (
+ sum(1 for v in exist_result.values() if v != 0)
+ * self.page_size
+ )

  if self.tp_world_size > 1:
  storage_hit_count_tensor = torch.tensor(
@@ -647,7 +680,8 @@ class HiCacheController:
  if storage_hit_count < self.prefetch_threshold:
  # not to prefetch if not enough benefits
  self.prefetch_revoke_queue.put(operation.request_id)
- self.mem_pool_host.free(operation.host_indices)
+ if operation.host_indices is not None:
+ self.mem_pool_host.free(operation.host_indices)
  logger.debug(
  f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
  )
@@ -670,12 +704,12 @@ class HiCacheController:
  self,
  host_indices: torch.Tensor,
  token_ids: List[int],
- last_hash: Optional[str] = None,
+ hash_value: Optional[List[str]] = None,
  ) -> int:
  """
  Write KV caches from host memory to storage backend.
  """
- operation = StorageOperation(host_indices, token_ids, last_hash)
+ operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
  self.backup_queue.put(operation)
  return operation.id

@@ -730,26 +764,10 @@ class HiCacheController:
  if operation is None:
  continue

- last_hash = operation.last_hash
- tokens_to_backup = operation.token_ids
-
- backup_hit_count = 0
- remaining_tokens = len(tokens_to_backup)
- hash_value = []
- while remaining_tokens >= self.page_size:
- last_hash = self.get_hash_str(
- tokens_to_backup[
- backup_hit_count : backup_hit_count + self.page_size
- ],
- last_hash,
- )
- backup_hit_count += self.page_size
- hash_value.append(last_hash)
- remaining_tokens -= self.page_size
- operation.hash_value = hash_value
-
  if self.is_mooncake_backend():
  self.mooncake_page_backup(operation)
+ elif self.storage_backend_type == "hf3fs":
+ self.generic_page_backup(operation, batch_size=128)
  else:
  self.generic_page_backup(operation)

@@ -768,7 +786,6 @@ class HiCacheController:
  self.ack_backup_queue.put(
  (
  operation.id,
- operation.hash_value[: min_completed_tokens // self.page_size],
  min_completed_tokens,
  )
  )
sglang/srt/managers/detokenizer_manager.py
@@ -216,7 +216,7 @@ class DetokenizerManager:
  rids=recv_obj.rids,
  finished_reasons=recv_obj.finished_reasons,
  output_strs=output_strs,
- output_ids=None,
+ output_ids=recv_obj.output_ids,
  prompt_tokens=recv_obj.prompt_tokens,
  completion_tokens=recv_obj.completion_tokens,
  cached_tokens=recv_obj.cached_tokens,
sglang/srt/managers/io_struct.py
@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
  from sglang.srt.managers.schedule_batch import BaseFinishReason
  from sglang.srt.multimodal.mm_utils import has_valid_data
  from sglang.srt.sampling.sampling_params import SamplingParams
+ from sglang.srt.utils import ImageData

  # Handle serialization of Image for pydantic
  if TYPE_CHECKING:
@@ -45,7 +46,7 @@ class SessionParams:

  # Type definitions for multimodal input data
  # Individual data item types for each modality
- ImageDataInputItem = Union[Image, str, Dict]
+ ImageDataInputItem = Union[Image, str, ImageData, Dict]
  AudioDataInputItem = Union[str, Dict]
  VideoDataInputItem = Union[str, Dict]
  # Union type for any multimodal data item
@@ -98,23 +99,24 @@ class GenerateReqInput:
  stream: bool = False
  # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
  log_metrics: bool = True
+ # Whether to return hidden states
+ return_hidden_states: Union[List[bool], bool] = False

  # The modalities of the image data [image, multi-images, video]
  modalities: Optional[List[str]] = None
- # The path to the LoRA
- lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-
  # Session info for continual prompting
  session_params: Optional[Union[List[Dict], Dict]] = None

+ # The path to the LoRA adaptors
+ lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+ # The uid of LoRA adaptors, should be initialized by tokenizer manager
+ lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
+
  # Custom logit processor for advanced sampling control. Must be a serialized instance
  # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
  # Use the processor's `to_str()` method to generate the serialized string.
  custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

- # Whether to return hidden states
- return_hidden_states: Union[List[bool], bool] = False
-
  # For disaggregated inference
  bootstrap_host: Optional[Union[List[str], str]] = None
  bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
@@ -123,6 +125,9 @@ class GenerateReqInput:
  # For data parallel rank routing
  data_parallel_rank: Optional[int] = None

+ # For background responses (OpenAI responses API)
+ background: bool = False
+
  def contains_mm_input(self) -> bool:
  return (
  has_valid_data(self.image_data)
@@ -450,6 +455,7 @@ class GenerateReqInput:
  log_metrics=self.log_metrics,
  modalities=self.modalities[i] if self.modalities else None,
  lora_path=self.lora_path[i] if self.lora_path is not None else None,
+ lora_id=self.lora_id[i] if self.lora_id is not None else None,
  custom_logit_processor=(
  self.custom_logit_processor[i]
  if self.custom_logit_processor is not None
@@ -500,7 +506,7 @@ class TokenizedGenerateReqInput:
  stream: bool

  # LoRA related
- lora_path: Optional[str] = None # None means just use the base model
+ lora_id: Optional[str] = None # None means just use the base model
  # The input embeds
  input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

@@ -557,6 +563,9 @@ class EmbeddingReqInput:
  # For cross-encoder requests
  is_cross_encoder_request: bool = False

+ # For background responses (OpenAI responses API)
+ background: bool = False
+
  def normalize_batch_and_arguments(self):
  # at least one of text, input_ids, or image should be provided
  if self.text is None and self.input_ids is None and self.image_data is None:
@@ -1073,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
  lora_name: str
  # The path of loading.
  lora_path: str
+ # Whether to pin the LoRA adapter in memory.
+ pinned: bool = False
  # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
  lora_id: Optional[str] = None

@@ -1081,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
  lora_id=self.lora_id,
  lora_name=self.lora_name,
  lora_path=self.lora_path,
+ pinned=self.pinned,
  )


sglang/srt/managers/mm_utils.py
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
  embedding_per_req = data_embedding_func(embedding_items_per_req)
  if not embedding_cache.put(embedding_items_hash, embedding_per_req):
  print_warning_once(
- "Multimodal embedding cache is full. Consider increasing the "
- "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+ "Multimodal embedding cache is full. This typically occurs when a single "
+ "embedding exceeds the cache size limit. Consider increasing the "
+ "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+ "embedding size."
  )

- embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+ embedding_per_req_chunk, _, _ = get_embedding_chunk(
  embedding=embedding_per_req,
  extend_prefix_len=prefix_length[i],
  extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
  items_offset=items_offset,
  )
- # remove this item from cache if chunk reaches to the end
- embedding_per_req_length = (
- embedding_per_req.shape[0]
- if embedding_per_req.dim() == 2
- else embedding_per_req.shape[0] * embedding_per_req.shape[1]
- )
- if end_index == embedding_per_req_length:
- embedding_cache.free(embedding_items_hash)
  embedding_list.append(embedding_per_req_chunk)
  if len(embedding_list) == 0:
  return None
@@ -620,8 +614,7 @@ def general_mm_embed_routine(
  input_ids: Input token IDs tensor
  forward_batch: Batch information for model forward pass
  language_model: Base language model to use
- image_data_embedding_func: Function to embed image data
- audio_data_embedding_func: Function to embed audio data
+ data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
  placeholder_tokens: Token IDs for multimodal placeholders
  **kwargs: Additional arguments passed to language model

sglang/srt/managers/multimodal_processor.py
@@ -20,7 +20,7 @@ def import_processors():
  try:
  module = importlib.import_module(name)
  except Exception as e:
- logger.warning(f"Ignore import error when loading {name}: " f"{e}")
+ logger.warning(f"Ignore import error when loading {name}: {e}")
  continue
  all_members = inspect.getmembers(module, inspect.isclass)
  classes = [
sglang/srt/managers/schedule_batch.py
@@ -37,6 +37,7 @@ import logging
  import threading
  from enum import Enum, auto
  from http import HTTPStatus
+ from itertools import chain
  from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

  import numpy as np
@@ -51,13 +52,13 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
  ScheduleBatchDisaggregationDecodeMixin,
  )
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
- from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.mem_cache.allocator import (
  BaseTokenToKVPoolAllocator,
  SWATokenToKVPoolAllocator,
  )
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+ from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.metrics.collector import TimeStats
@@ -85,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "disable_radix_cache",
  "enable_dp_attention",
  "enable_two_batch_overlap",
+ "tbo_token_distribution_threshold",
  "enable_dp_lm_head",
  "moe_a2a_backend",
  "deepep_mode",
@@ -107,8 +109,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "num_reserved_decode_tokens",
  "weight_loader_disable_mmap",
  "enable_triton_kernel_moe",
+ "enable_flashinfer_mxfp4_moe",
  "enable_multimodal",
  "enable_symm_mem",
+ "quantization",
  ]

  # Put some global args for easy access
@@ -423,7 +427,7 @@ class Req:
  token_ids_logprob: List[int] = None,
  stream: bool = False,
  origin_input_ids_unpadded: Optional[Tuple[int]] = None,
- lora_path: Optional[str] = None,
+ lora_id: Optional[str] = None,
  input_embeds: Optional[List[List[float]]] = None,
  token_type_ids: List[int] = None,
  session_id: Optional[str] = None,
@@ -467,7 +471,7 @@ class Req:
  self.sampling_params = sampling_params
  self.custom_logit_processor = custom_logit_processor
  self.return_hidden_states = return_hidden_states
- self.lora_path = lora_path
+ self.lora_id = lora_id

  # Memory pool info
  self.req_pool_idx: Optional[int] = None
@@ -636,14 +640,26 @@ class Req:
  ):
  self.fill_ids = self.origin_input_ids + self.output_ids
  if tree_cache is not None:
- (
- self.prefix_indices,
- self.last_node,
- self.last_host_node,
- self.host_hit_length,
- ) = tree_cache.match_prefix(
- key=self.adjust_max_prefix_ids(),
- )
+ if isinstance(tree_cache, LoRARadixCache):
+ (
+ self.prefix_indices,
+ self.last_node,
+ self.last_host_node,
+ self.host_hit_length,
+ ) = tree_cache.match_prefix_with_lora_id(
+ key=LoRAKey(
+ lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
+ ),
+ )
+ else:
+ (
+ self.prefix_indices,
+ self.last_node,
+ self.last_host_node,
+ self.host_hit_length,
+ ) = tree_cache.match_prefix(
+ key=self.adjust_max_prefix_ids(),
+ )
  self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

  def adjust_max_prefix_ids(self):
@@ -845,6 +861,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

  # The sum of all sequence lengths
  seq_lens_sum: int = None
+ # The original sequence lengths, Qwen-1M related
+ orig_seq_lens: torch.Tensor = None # shape: [b], int32

  # For DP attention
  global_num_tokens: Optional[List[int]] = None
@@ -917,8 +935,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

  is_hybrid = False
  if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
- assert isinstance(tree_cache, SWARadixCache) or isinstance(
- tree_cache, SWAChunkCache
+ assert (
+ tree_cache is None
+ or isinstance(tree_cache, SWARadixCache)
+ or isinstance(tree_cache, SWAChunkCache)
  ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
  is_hybrid = True

@@ -1128,6 +1148,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
  extend_num_tokens = sum(len(ids) for ids in input_ids)
  seq_lens = [len(r.fill_ids) for r in reqs]
+ orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
  prefix_lens = [len(r.prefix_indices) for r in reqs]
  extend_lens = [r.extend_input_len for r in reqs]

@@ -1138,10 +1159,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
  self.device, non_blocking=True
  )
- input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
+ input_ids_tensor = torch.tensor(
+ list(chain.from_iterable(input_ids)), dtype=torch.int64
+ ).to(self.device, non_blocking=True)
+ seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
  self.device, non_blocking=True
  )
- seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
+ orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
  self.device, non_blocking=True
  )
  prefix_lens_tensor = torch.tensor(
@@ -1257,6 +1281,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.input_ids = input_ids_tensor
  self.req_pool_indices = req_pool_indices_tensor
  self.seq_lens = seq_lens_tensor
+ self.orig_seq_lens = orig_seq_lens_tensor
  self.out_cache_loc = out_cache_loc
  self.input_embeds = (
  torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1504,6 +1529,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.forward_mode = ForwardMode.IDLE
  self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
  self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+ self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
  self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
  self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
  self.seq_lens_sum = 0
@@ -1558,9 +1584,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  if self.enable_overlap:
  # Do not use in-place operations in the overlap mode
  self.seq_lens = self.seq_lens + 1
+ self.orig_seq_lens = self.orig_seq_lens + 1
  else:
  # A faster in-place version
  self.seq_lens.add_(1)
+ self.orig_seq_lens.add_(1)
  self.seq_lens_sum += bs

  # free memory
@@ -1624,6 +1652,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
  self.req_pool_indices = self.req_pool_indices[keep_indices_device]
  self.seq_lens = self.seq_lens[keep_indices_device]
+ self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
  self.out_cache_loc = None
  self.seq_lens_sum = self.seq_lens.sum().item()
  self.output_ids = self.output_ids[keep_indices_device]
@@ -1656,6 +1685,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  [self.req_pool_indices, other.req_pool_indices]
  )
  self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+ self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
  self.out_cache_loc = None
  self.seq_lens_sum += other.seq_lens_sum
  if self.output_ids is not None:
@@ -1697,14 +1727,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  attention_backend_str = global_server_args_dict["prefill_attention_backend"]
  # Create seq_lens_cpu when needed
  if (
- attention_backend_str == "fa3"
- or (
- global_server_args_dict["use_mla_backend"]
- and attention_backend_str == "flashinfer"
- )
- or attention_backend_str == "flashmla"
- or attention_backend_str == "cutlass_mla"
- or attention_backend_str == "ascend"
+ attention_backend_str
+ in [
+ "fa3",
+ "flashinfer",
+ "flashmla",
+ "cutlass_mla",
+ "ascend",
+ "trtllm_mha",
+ "aiter",
+ ]
  or global_server_args_dict["enable_two_batch_overlap"]
  ):
  seq_lens_cpu = (
@@ -1729,6 +1761,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  input_ids=self.input_ids,
  req_pool_indices=self.req_pool_indices,
  seq_lens=self.seq_lens,
+ orig_seq_lens=self.orig_seq_lens,
  out_cache_loc=self.out_cache_loc,
  seq_lens_cpu=seq_lens_cpu,
  seq_lens_sum=self.seq_lens_sum,
@@ -1750,7 +1783,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  encoder_lens=self.encoder_lens,
  encoder_lens_cpu=self.encoder_lens_cpu,
  encoder_out_cache_loc=self.encoder_out_cache_loc,
- lora_paths=[req.lora_path for req in self.reqs],
+ lora_ids=[req.lora_id for req in self.reqs],
  sampling_info=self.sampling_info,
  input_embeds=self.input_embeds,
  token_type_ids=self.token_type_ids,
@@ -1891,11 +1924,14 @@ class ModelWorkerBatch:
  encoder_out_cache_loc: Optional[torch.Tensor]

  # For LoRA
- lora_paths: Optional[List[str]]
+ lora_ids: Optional[List[str]]

  # Sampling info
  sampling_info: SamplingBatchInfo

+ # The original sequence lengths, Qwen-1M related
+ orig_seq_lens: Optional[torch.Tensor] = None
+
  # The input Embeds
  input_embeds: Optional[torch.Tensor] = None