sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0

sglang/srt/managers/cache_controller.py

@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
 
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
         self._done_flag = False
         self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
@@ -236,18 +239,19 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.get_hash_str = get_hash_str
             elif storage_backend == "nixl":
-                from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+                from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
                 self.storage_backend = HiCacheNixl()
                 self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
-                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                     MooncakeStore,
                     get_hash_str_mooncake,
                 )
@@ -277,6 +281,12 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
         # create a new communication group for synchronizing storage operations across TP workers
         self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
         if self.tp_world_size > 1:
@@ -524,7 +534,7 @@ class HiCacheController:
         host_indices: torch.Tensor,
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
-    ) -> int:
+    ) -> PrefetchOperation:
         """
         Prefetch KV caches from storage backend to host memory.
         """
@@ -573,6 +583,9 @@ class HiCacheController:
         self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
         operation.increment(len(operation.hash_value) * self.page_size)
 
+    def is_mooncake_backend(self):
+        return self.storage_backend_type == "mooncake"
+
     def prefetch_io_aux_func(self):
         """
         Auxiliary function conducting IO operations for prefetching.
@@ -580,13 +593,25 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if isinstance(self.storage_backend, MooncakeStore):
+                if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
             except Empty:
                 continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit the prefetching operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much memory is occupied
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
@@ -600,34 +625,36 @@ class HiCacheController:
                 if operation is None:
                     continue
 
-                last_hash = operation.last_hash
-                tokens_to_fetch = operation.token_ids
-
                 storage_hit_count = 0
-                remaining_tokens = len(tokens_to_fetch)
-                hash_value = []
-                while remaining_tokens >= self.page_size:
-                    last_hash = self.get_hash_str(
-                        tokens_to_fetch[
-                            storage_hit_count : storage_hit_count + self.page_size
-                        ],
-                        last_hash,
-                    )
-
-                    # todo, more unified interface
-                    if not isinstance(self.storage_backend, MooncakeStore):
-                        if not self.storage_backend.exists(last_hash):
-                            break
-                    hash_value.append(last_hash)
-                    storage_hit_count += self.page_size
-                    remaining_tokens -= self.page_size
-
-                if isinstance(self.storage_backend, MooncakeStore):
-                    # deferring to batch exists for mooncake store
-                    exist_result = self.storage_backend.exists(hash_value)
-                    storage_hit_count = (
-                        sum(1 for v in exist_result.values() if v != 0) * self.page_size
-                    )
+                if self.prefetch_rate_limit_check():
+                    last_hash = operation.last_hash
+                    tokens_to_fetch = operation.token_ids
+
+                    remaining_tokens = len(tokens_to_fetch)
+                    hash_value = []
+                    while remaining_tokens >= self.page_size:
+                        last_hash = self.get_hash_str(
+                            tokens_to_fetch[
+                                storage_hit_count : storage_hit_count + self.page_size
+                            ],
+                            last_hash,
+                        )
+
+                        # todo, more unified interface
+                        if not self.is_mooncake_backend():
+                            if not self.storage_backend.exists(last_hash):
+                                break
+                        hash_value.append(last_hash)
+                        storage_hit_count += self.page_size
+                        remaining_tokens -= self.page_size
+
+                    if self.is_mooncake_backend():
+                        # deferring to batch exists for mooncake store
+                        exist_result = self.storage_backend.exists(hash_value)
+                        storage_hit_count = (
+                            sum(1 for v in exist_result.values() if v != 0)
+                            * self.page_size
+                        )
 
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
@@ -744,8 +771,10 @@ class HiCacheController:
                 remaining_tokens -= self.page_size
             operation.hash_value = hash_value
 
-            if isinstance(self.storage_backend, MooncakeStore):
+            if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
+            elif self.storage_backend_type == "hf3fs":
+                self.generic_page_backup(operation, batch_size=128)
             else:
                 self.generic_page_backup(operation)
 
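
Note on the hunks above: the new prefetch admission control is a plain capacity check that skips prefetching once the tokens already locked by in-flight prefetches reach 80% of the host-only slice of the hierarchical cache. A minimal, self-contained sketch of that policy (hypothetical names, not sglang's classes):

    # Illustrative only: PrefetchBudget and its fields are hypothetical stand-ins
    # for HiCacheController.prefetch_capacity_limit / prefetch_tokens_occupied.
    class PrefetchBudget:
        def __init__(self, host_pool_tokens: int, device_pool_tokens: int):
            # Only the host-exclusive portion of the pool is eligible for prefetch.
            self.capacity_limit = int(0.8 * (host_pool_tokens - device_pool_tokens))
            self.tokens_occupied = 0  # maintained by the scheduler thread

        def admit(self) -> bool:
            # Mirrors prefetch_rate_limit_check: refuse new prefetches when full.
            return self.tokens_occupied < self.capacity_limit

    budget = PrefetchBudget(host_pool_tokens=1_000_000, device_pool_tokens=200_000)
    assert budget.capacity_limit == 640_000
    assert budget.admit()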

sglang/srt/managers/data_parallel_controller.py

@@ -16,9 +16,13 @@
 import logging
 import multiprocessing as mp
 import signal
+import struct
+import sys
 import threading
 import time
 from enum import Enum, auto
+from multiprocessing import shared_memory
+from typing import Dict, List
 
 import psutil
 import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
 
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    MINIMUM_TOKENS = auto()
 
     @classmethod
    def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class DataParallelController:
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
 
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        dp_balance_meta: DPBalanceMeta,
+    ) -> None:
+        # for dp balance
+        self.global_balance_id = 0
+        self.balance_meta = dp_balance_meta
+
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
@@ -234,6 +250,7 @@ class DataParallelController:
                     pp_rank,
                     dp_rank,
                     writer,
+                    self.balance_meta,
                 ),
             )
            with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
 
+    def minimum_tokens_scheduler(self, req):
+        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
+        def get_next_global_balance_id() -> int:
+            INT32_MAX = 2147483647
+            current_id = self.global_balance_id
+            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
+            return current_id
+
+        req.dp_balance_id = get_next_global_balance_id()
+        with self.balance_meta.mutex:
+            # 1. local_tokens represents the tokens currently inferring on the worker,
+            # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. write the new onfly info to the shm
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
     def event_loop(self):
         while True:
             while True:
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
     setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size)
 
     try:
-        controller = DataParallelController(server_args, port_args)
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
         pipe_writer.send(
             {
                 "status": "ready",
@@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # we need to destruct mp.Manager() in balance_meta
+        balance_meta.destructor()
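
For reference, the MINIMUM_TOKENS policy added above picks the DP worker with the smallest sum of locally running tokens plus "onfly" tokens (dispatched but not yet received by the scheduler). A standalone sketch of the same selection rule, with plain lists and dicts standing in for the shared-memory metadata:

    from typing import Dict, List

    def pick_min_token_worker(
        local_tokens: List[int],           # tokens currently inferring on each worker
        onfly_info: List[Dict[int, int]],  # per worker: balance_id -> dispatched-but-unreceived tokens
    ) -> int:
        # Same rule as minimum_tokens_scheduler: minimize local + onfly load.
        total_tokens = [
            local + sum(onfly.values()) for local, onfly in zip(local_tokens, onfly_info)
        ]
        return total_tokens.index(min(total_tokens))

    # Worker 1 has the lowest combined load (300 + 250 = 550), so it gets the next request.
    assert pick_min_token_worker([1200, 300, 900], [{7: 100}, {8: 250}, {9: 0}]) == 1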

sglang/srt/managers/detokenizer_manager.py

@@ -216,7 +216,7 @@ class DetokenizerManager:
                 rids=recv_obj.rids,
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-                output_ids=None,
+                output_ids=recv_obj.decode_ids,
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,

sglang/srt/managers/io_struct.py

@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.utils import ImageData
 
 # Handle serialization of Image for pydantic
 if TYPE_CHECKING:
@@ -45,7 +46,7 @@ class SessionParams:
 
 # Type definitions for multimodal input data
 # Individual data item types for each modality
-ImageDataInputItem = Union[Image, str, Dict]
+ImageDataInputItem = Union[Image, str, ImageData, Dict]
 AudioDataInputItem = Union[str, Dict]
 VideoDataInputItem = Union[str, Dict]
 # Union type for any multimodal data item
@@ -101,8 +102,10 @@ class GenerateReqInput:
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # The path to the LoRA
+    # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # The uid of LoRA adaptors, should be initialized by tokenizer manager
+    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
@@ -123,6 +126,9 @@ class GenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -500,7 +506,7 @@ class TokenizedGenerateReqInput:
     stream: bool
 
     # LoRA related
-    lora_path: Optional[str] = None  # None means just use the base model
+    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -523,6 +529,9 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For dp balance
+    dp_balance_id: int = -1
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -554,6 +563,9 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -648,6 +660,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For dp balance
+    dp_balance_id: int = -1
 
 
 @dataclass
@@ -1068,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
     lora_path: str
+    # Whether to pin the LoRA adapter in memory.
+    pinned: bool = False
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
     lora_id: Optional[str] = None
 
@@ -1076,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
             lora_id=self.lora_id,
             lora_name=self.lora_name,
             lora_path=self.lora_path,
+            pinned=self.pinned,
         )
 
 
@@ -1097,7 +1114,7 @@ class UnloadLoRAAdapterReqInput:
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
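
Illustrative use of the new `pinned` flag on `LoadLoRAAdapterReqInput` (only fields visible in this hunk are used; the values are made up):

    from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

    req = LoadLoRAAdapterReqInput(
        lora_name="my-adapter",                # hypothetical adapter name
        lora_path="/models/loras/my-adapter",  # hypothetical local path
        pinned=True,                           # new field: keep the adapter resident in memory
    )
    # lora_id is left as None; per the comment above, the TokenizerManager fills it in.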

sglang/srt/managers/mm_utils.py

@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
         embedding_per_req = data_embedding_func(embedding_items_per_req)
         if not embedding_cache.put(embedding_items_hash, embedding_per_req):
             print_warning_once(
-                "Multimodal embedding cache is full. Consider increasing the "
-                "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+                "Multimodal embedding cache is full. This typically occurs when a single "
+                "embedding exceeds the cache size limit. Consider increasing the "
+                "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+                "embedding size."
             )
 
-        embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+        embedding_per_req_chunk, _, _ = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
             extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
-        # remove this item from cache if chunk reaches to the end
-        embedding_per_req_length = (
-            embedding_per_req.shape[0]
-            if embedding_per_req.dim() == 2
-            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
-        )
-        if end_index == embedding_per_req_length:
-            embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
         return None

sglang/srt/managers/schedule_batch.py

@@ -84,10 +84,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_radix_cache",
     "enable_dp_attention",
     "enable_two_batch_overlap",
+    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
-    "enable_deepep_moe",
+    "moe_a2a_backend",
     "deepep_mode",
-    "enable_ep_moe",
     "enable_flashinfer_cutlass_moe",
     "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
@@ -107,7 +107,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
+    "enable_symm_mem",
+    "quantization",
 ]
 
 # Put some global args for easy access
@@ -422,7 +425,7 @@ class Req:
         token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
-        lora_path: Optional[str] = None,
+        lora_id: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
@@ -466,7 +469,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
-        self.lora_path = lora_path
+        self.lora_id = lora_id
 
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -844,6 +847,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32
 
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
@@ -916,8 +921,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         is_hybrid = False
         if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
-            assert isinstance(tree_cache, SWARadixCache) or isinstance(
-                tree_cache, SWAChunkCache
+            assert (
+                tree_cache is None
+                or isinstance(tree_cache, SWARadixCache)
+                or isinstance(tree_cache, SWAChunkCache)
             ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
             is_hybrid = True
 
@@ -1127,6 +1134,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
@@ -1143,6 +1151,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
@@ -1256,6 +1267,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1503,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
@@ -1557,9 +1570,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
         # free memory
@@ -1623,6 +1638,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
@@ -1655,6 +1671,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
@@ -1704,6 +1721,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             or attention_backend_str == "flashmla"
             or attention_backend_str == "cutlass_mla"
             or attention_backend_str == "ascend"
+            or attention_backend_str == "trtllm_mha"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
@@ -1728,6 +1746,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_cpu=seq_lens_cpu,
             seq_lens_sum=self.seq_lens_sum,
@@ -1749,7 +1768,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
             encoder_out_cache_loc=self.encoder_out_cache_loc,
-            lora_paths=[req.lora_path for req in self.reqs],
+            lora_ids=[req.lora_id for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
             token_type_ids=self.token_type_ids,
@@ -1890,11 +1909,14 @@ class ModelWorkerBatch:
     encoder_out_cache_loc: Optional[torch.Tensor]
 
     # For LoRA
-    lora_paths: Optional[List[str]]
+    lora_ids: Optional[List[str]]
 
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # The input Embeds
     input_embeds: Optional[torch.Tensor] = None
 

sglang/srt/managers/schedule_policy.py

@@ -455,7 +455,9 @@ class PrefillAdder:
         if not self.is_hybrid:
             # Skip this logic for swa. The SWA has different memory management, and
             # this mechanism is underestimating the memory usage.
-            cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+            cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(
+                req.extend_input_len
+            )
             tokens_freed = 0
             for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                 # tokens_left gives a reservative calculation as the last token is not stored
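
In the hunk above, the prefill budget is now charged with the extend length rounded up to whole pages rather than the raw prompt length. Assuming `ceil_paged_tokens` has the usual round-up-to-page-size semantics (the helper itself is not shown in this diff), the arithmetic is:

    def ceil_paged_tokens(num_tokens: int, page_size: int) -> int:
        # Round up to a whole number of KV-cache pages (assumed semantics).
        return -(-num_tokens // page_size) * page_size

    assert ceil_paged_tokens(1000, page_size=64) == 1024  # 16 pages
    assert ceil_paged_tokens(1024, page_size=64) == 1024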