sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -169,12 +169,13 @@ class StorageOperation:
  host_indices: torch.Tensor,
  token_ids: List[int],
  last_hash: Optional[str] = None,
+ hash_value: Optional[List[str]] = None,
  ):
  self.host_indices = host_indices
  self.token_ids = token_ids
  self.last_hash = last_hash
  self.completed_tokens = 0
- self.hash_value = []
+ self.hash_value = hash_value if hash_value is not None else []

  self.id = StorageOperation.counter
  StorageOperation.counter += 1
@@ -259,6 +260,7 @@ class HiCacheController:
  self.storage_backend = MooncakeStore()
  self.get_hash_str = get_hash_str_mooncake
  self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+ assert self.mem_pool_host.layout == "page_first"
  elif storage_backend == "hf3fs":
  from sglang.srt.distributed import get_tensor_model_parallel_rank
  from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
@@ -294,6 +296,9 @@ class HiCacheController:
  self.prefetch_tp_group = torch.distributed.new_group(
  group_ranks, backend="gloo"
  )
+ self.prefetch_io_tp_group = torch.distributed.new_group(
+ group_ranks, backend="gloo"
+ )
  self.backup_tp_group = torch.distributed.new_group(
  group_ranks, backend="gloo"
  )
@@ -433,7 +438,9 @@ class HiCacheController:
  if self.io_backend == "kernel":
  return host_indices.to(self.mem_pool_device.device), device_indices
  elif self.io_backend == "direct":
- return host_indices, device_indices.cpu()
+ device_indices = device_indices.cpu()
+ host_indices, idx = host_indices.sort()
+ return host_indices, device_indices.index_select(0, idx)
  else:
  raise ValueError(f"Unsupported io backend")
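Note on the `direct` io backend change above: sorting the host indices and applying the same permutation to the device indices keeps each (host, device) pair intact while making host memory accesses monotonic. A minimal illustration with plain torch (the index values are made up):

    import torch

    host = torch.tensor([7, 2, 9, 4])            # hypothetical host slots
    device = torch.tensor([100, 101, 102, 103])  # matching device slots

    # sort() returns the sorted values plus the permutation that produced them;
    # reordering the device side with the same permutation preserves the pairing.
    host_sorted, perm = host.sort()
    device_reordered = device.index_select(0, perm)

    assert set(zip(host_sorted.tolist(), device_reordered.tolist())) == set(
        zip(host.tolist(), device.tolist())
    )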
@@ -570,10 +577,6 @@ class HiCacheController:
  )
  completed_tokens += self.page_size
  else:
- # operation terminated by controller, release pre-allocated memory
- self.mem_pool_host.free(
- operation.host_indices[operation.completed_tokens :]
- )
  break

  def mooncake_page_transfer(self, operation):
@@ -599,6 +602,14 @@ class HiCacheController:
  self.generic_page_transfer(operation, batch_size=128)
  else:
  self.generic_page_transfer(operation)
+
+ if self.tp_world_size > 1:
+ # to ensure all TP workers release the host memory at the same time
+ torch.distributed.barrier(group=self.prefetch_io_tp_group)
+ # operation terminated by controller, release pre-allocated memory
+ self.mem_pool_host.free(
+ operation.host_indices[operation.completed_tokens :]
+ )
  except Empty:
  continue

@@ -626,7 +637,9 @@ class HiCacheController:
  continue

  storage_hit_count = 0
- if self.prefetch_rate_limit_check():
+ if (
+ operation.host_indices is not None
+ ) and self.prefetch_rate_limit_check():
  last_hash = operation.last_hash
  tokens_to_fetch = operation.token_ids

@@ -670,7 +683,8 @@ class HiCacheController:
  if storage_hit_count < self.prefetch_threshold:
  # not to prefetch if not enough benefits
  self.prefetch_revoke_queue.put(operation.request_id)
- self.mem_pool_host.free(operation.host_indices)
+ if operation.host_indices is not None:
+ self.mem_pool_host.free(operation.host_indices)
  logger.debug(
  f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
  )
@@ -693,12 +707,12 @@ class HiCacheController:
  self,
  host_indices: torch.Tensor,
  token_ids: List[int],
- last_hash: Optional[str] = None,
+ hash_value: Optional[List[str]] = None,
  ) -> int:
  """
  Write KV caches from host memory to storage backend.
  """
- operation = StorageOperation(host_indices, token_ids, last_hash)
+ operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
  self.backup_queue.put(operation)
  return operation.id

@@ -753,24 +767,6 @@ class HiCacheController:
  if operation is None:
  continue

- last_hash = operation.last_hash
- tokens_to_backup = operation.token_ids
-
- backup_hit_count = 0
- remaining_tokens = len(tokens_to_backup)
- hash_value = []
- while remaining_tokens >= self.page_size:
- last_hash = self.get_hash_str(
- tokens_to_backup[
- backup_hit_count : backup_hit_count + self.page_size
- ],
- last_hash,
- )
- backup_hit_count += self.page_size
- hash_value.append(last_hash)
- remaining_tokens -= self.page_size
- operation.hash_value = hash_value
-
  if self.is_mooncake_backend():
  self.mooncake_page_backup(operation)
  elif self.storage_backend_type == "hf3fs":
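With the per-page hashing removed from the backup thread above, callers are expected to pass a precomputed `hash_value` list when enqueueing a backup. A rough sketch of the chained per-page hashing the deleted block performed, assuming a `get_hash_str(tokens, prior_hash)` helper with the same contract as in the removed code:

    def compute_page_hashes(token_ids, page_size, get_hash_str, last_hash=None):
        # Hash one full page at a time, chaining each page hash off the previous
        # one, mirroring the loop that used to run inside the backup thread.
        hash_value = []
        full = len(token_ids) - len(token_ids) % page_size
        for start in range(0, full, page_size):
            last_hash = get_hash_str(token_ids[start : start + page_size], last_hash)
            hash_value.append(last_hash)
        return hash_value

The resulting list is what the write method shown earlier now receives through its `hash_value` parameter.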
@@ -793,7 +789,6 @@ class HiCacheController:
  self.ack_backup_queue.put(
  (
  operation.id,
- operation.hash_value[: min_completed_tokens // self.page_size],
  min_completed_tokens,
  )
  )
sglang/srt/managers/detokenizer_manager.py

@@ -216,7 +216,7 @@ class DetokenizerManager:
  rids=recv_obj.rids,
  finished_reasons=recv_obj.finished_reasons,
  output_strs=output_strs,
- output_ids=recv_obj.decode_ids,
+ output_ids=recv_obj.output_ids,
  prompt_tokens=recv_obj.prompt_tokens,
  completion_tokens=recv_obj.completion_tokens,
  cached_tokens=recv_obj.cached_tokens,
sglang/srt/managers/io_struct.py

@@ -99,25 +99,24 @@ class GenerateReqInput:
  stream: bool = False
  # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
  log_metrics: bool = True
+ # Whether to return hidden states
+ return_hidden_states: Union[List[bool], bool] = False

  # The modalities of the image data [image, multi-images, video]
  modalities: Optional[List[str]] = None
+ # Session info for continual prompting
+ session_params: Optional[Union[List[Dict], Dict]] = None
+
  # The path to the LoRA adaptors
  lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
  # The uid of LoRA adaptors, should be initialized by tokenizer manager
  lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

- # Session info for continual prompting
- session_params: Optional[Union[List[Dict], Dict]] = None
-
  # Custom logit processor for advanced sampling control. Must be a serialized instance
  # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
  # Use the processor's `to_str()` method to generate the serialized string.
  custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

- # Whether to return hidden states
- return_hidden_states: Union[List[bool], bool] = False
-
  # For disaggregated inference
  bootstrap_host: Optional[Union[List[str], str]] = None
  bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
@@ -456,6 +455,7 @@ class GenerateReqInput:
  log_metrics=self.log_metrics,
  modalities=self.modalities[i] if self.modalities else None,
  lora_path=self.lora_path[i] if self.lora_path is not None else None,
+ lora_id=self.lora_id[i] if self.lora_id is not None else None,
  custom_logit_processor=(
  self.custom_logit_processor[i]
  if self.custom_logit_processor is not None
@@ -798,6 +798,8 @@ class UpdateWeightFromDiskReqInput:
  load_format: Optional[str] = None
  # Whether to abort all requests before updating weights
  abort_all_requests: bool = False
+ # Optional: Update weight version along with weights
+ weight_version: Optional[str] = None


  @dataclass
@@ -819,6 +821,8 @@ class UpdateWeightsFromDistributedReqInput:
  flush_cache: bool = True
  # Whether to abort all requests before updating weights
  abort_all_requests: bool = False
+ # Optional: Update weight version along with weights
+ weight_version: Optional[str] = None


  @dataclass
@@ -842,6 +846,8 @@ class UpdateWeightsFromTensorReqInput:
  flush_cache: bool = True
  # Whether to abort all requests before updating weights
  abort_all_requests: bool = False
+ # Optional: Update weight version along with weights
+ weight_version: Optional[str] = None


  @dataclass
@@ -872,6 +878,14 @@ class InitWeightsUpdateGroupReqOutput:
  message: str


+ @dataclass
+ class UpdateWeightVersionReqInput:
+ # The new weight version
+ new_version: str
+ # Whether to abort all running requests before updating
+ abort_all_requests: bool = True
+
+
  @dataclass
  class GetWeightsByNameReqInput:
  name: str
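The new optional `weight_version` field lets a caller tag a weight update with a version string. A hedged sketch of how a disk-based update request might carry it (the `model_path` field and its value are assumptions about the existing dataclass; only the fields shown in the hunks above are confirmed by this diff):

    from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput

    # Hypothetical usage: reload weights from disk and record a version label.
    req = UpdateWeightFromDiskReqInput(
        model_path="/checkpoints/my-model",  # assumed field; path is illustrative
        abort_all_requests=True,
        weight_version="2025-01-15-a",
    )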
sglang/srt/managers/mm_utils.py

@@ -614,8 +614,7 @@ def general_mm_embed_routine(
  input_ids: Input token IDs tensor
  forward_batch: Batch information for model forward pass
  language_model: Base language model to use
- image_data_embedding_func: Function to embed image data
- audio_data_embedding_func: Function to embed audio data
+ data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
  placeholder_tokens: Token IDs for multimodal placeholders
  **kwargs: Additional arguments passed to language model
sglang/srt/managers/multimodal_processor.py

@@ -20,7 +20,7 @@ def import_processors():
  try:
  module = importlib.import_module(name)
  except Exception as e:
- logger.warning(f"Ignore import error when loading {name}: " f"{e}")
+ logger.warning(f"Ignore import error when loading {name}: {e}")
  continue
  all_members = inspect.getmembers(module, inspect.isclass)
  classes = [
sglang/srt/managers/schedule_batch.py

@@ -37,6 +37,7 @@ import logging
  import threading
  from enum import Enum, auto
  from http import HTTPStatus
+ from itertools import chain
  from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

  import numpy as np
@@ -57,6 +58,7 @@ from sglang.srt.mem_cache.allocator import (
  )
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+ from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.metrics.collector import TimeStats
@@ -82,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "device",
  "disable_chunked_prefix_cache",
  "disable_radix_cache",
- "enable_dp_attention",
  "enable_two_batch_overlap",
  "tbo_token_distribution_threshold",
  "enable_dp_lm_head",
@@ -111,6 +112,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "enable_multimodal",
  "enable_symm_mem",
  "quantization",
+ "enable_custom_logit_processor",
  ]

  # Put some global args for easy access
@@ -638,14 +640,26 @@ class Req:
  ):
  self.fill_ids = self.origin_input_ids + self.output_ids
  if tree_cache is not None:
- (
- self.prefix_indices,
- self.last_node,
- self.last_host_node,
- self.host_hit_length,
- ) = tree_cache.match_prefix(
- key=self.adjust_max_prefix_ids(),
- )
+ if isinstance(tree_cache, LoRARadixCache):
+ (
+ self.prefix_indices,
+ self.last_node,
+ self.last_host_node,
+ self.host_hit_length,
+ ) = tree_cache.match_prefix_with_lora_id(
+ key=LoRAKey(
+ lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
+ ),
+ )
+ else:
+ (
+ self.prefix_indices,
+ self.last_node,
+ self.last_host_node,
+ self.host_hit_length,
+ ) = tree_cache.match_prefix(
+ key=self.adjust_max_prefix_ids(),
+ )
  self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

  def adjust_max_prefix_ids(self):
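The LoRA-aware prefix match above keys the radix cache on both the adapter id and the token prefix, so KV cached under one adapter is never reused for another. A minimal sketch of such a key (a stand-in for illustration only, not the actual `LoRAKey` from the new lora_radix_cache module):

    from dataclasses import dataclass
    from typing import Optional, Tuple


    @dataclass(frozen=True)
    class LoRAKeySketch:
        # Stand-in for LoRAKey: equality requires both fields to match.
        lora_id: Optional[str]
        token_ids: Tuple[int, ...]


    same_tokens = (1, 2, 3)
    # Same prefix, different adapters -> different cache entries.
    assert LoRAKeySketch("adapter-a", same_tokens) != LoRAKeySketch("adapter-b", same_tokens)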
@@ -895,12 +909,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  spec_algorithm: SpeculativeAlgorithm = None
  spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

- # Enable custom logit processor
- enable_custom_logit_processor: bool = False
-
  # Whether to return hidden states
  return_hidden_states: bool = False

+ # Whether this batch is prefill-only (no token generation needed)
+ is_prefill_only: bool = False
+
  # hicache pointer for synchronizing data loading from CPU to GPU
  hicache_consumer_index: int = 0

@@ -914,7 +928,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  model_config: ModelConfig,
  enable_overlap: bool,
  spec_algorithm: SpeculativeAlgorithm,
- enable_custom_logit_processor: bool,
  chunked_req: Optional[Req] = None,
  ):
  return_logprob = any(req.return_logprob for req in reqs)
@@ -941,8 +954,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  has_grammar=any(req.grammar for req in reqs),
  device=req_to_token_pool.device,
  spec_algorithm=spec_algorithm,
- enable_custom_logit_processor=enable_custom_logit_processor,
  return_hidden_states=any(req.return_hidden_states for req in reqs),
+ is_prefill_only=all(
+ req.sampling_params.max_new_tokens == 0 for req in reqs
+ ),
  chunked_req=chunked_req,
  )
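The `is_prefill_only` flag above is set exactly when every request in the batch asks for zero new tokens, which is what later lets the scheduler skip merging such a batch into the running (decoding) batch (see the scheduler hunk further down). A tiny illustration of the predicate with stand-in request objects:

    from types import SimpleNamespace

    reqs = [
        SimpleNamespace(sampling_params=SimpleNamespace(max_new_tokens=0)),
        SimpleNamespace(sampling_params=SimpleNamespace(max_new_tokens=0)),
    ]
    # True only if no request needs any generated tokens.
    assert all(r.sampling_params.max_new_tokens == 0 for r in reqs)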
@@ -995,6 +1010,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  extend_num_tokens: int,
  backup_state: bool = False,
  ):
+ # Over estimate the number of tokens: assume each request needs a new page.
  num_tokens = (
  extend_num_tokens
  + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
@@ -1027,8 +1043,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  last_loc: torch.Tensor,
  backup_state: bool = False,
  ):
+ # Over estimate the number of tokens: assume each request needs a new page.
  num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-
  self._evict_tree_cache_if_needed(num_tokens)

  if backup_state:
@@ -1145,9 +1161,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
  self.device, non_blocking=True
  )
- input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
- self.device, non_blocking=True
- )
+ input_ids_tensor = torch.tensor(
+ list(chain.from_iterable(input_ids)), dtype=torch.int64
+ ).to(self.device, non_blocking=True)
  seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
  self.device, non_blocking=True
  )
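Replacing `sum(input_ids, [])` with `itertools.chain.from_iterable` avoids the repeated list copies that `sum` incurs while producing the same flattened list:

    from itertools import chain

    input_ids = [[1, 2], [3], [4, 5, 6]]
    flat_sum = sum(input_ids, [])                      # builds a new list on every step
    flat_chain = list(chain.from_iterable(input_ids))  # single pass
    assert flat_sum == flat_chain == [1, 2, 3, 4, 5, 6]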
@@ -1707,37 +1723,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  extend_prefix_lens = self.prefix_lens
  extend_logprob_start_lens = self.extend_logprob_start_lens

- if self.forward_mode.is_decode_or_idle():
- attention_backend_str = global_server_args_dict["decode_attention_backend"]
- else:
- attention_backend_str = global_server_args_dict["prefill_attention_backend"]
- # Create seq_lens_cpu when needed
- if (
- attention_backend_str == "fa3"
- or (
- global_server_args_dict["use_mla_backend"]
- and attention_backend_str == "flashinfer"
- )
- or attention_backend_str == "flashmla"
- or attention_backend_str == "cutlass_mla"
- or attention_backend_str == "ascend"
- or attention_backend_str == "trtllm_mha"
- or global_server_args_dict["enable_two_batch_overlap"]
- ):
- seq_lens_cpu = (
- seq_lens_cpu_cache
- if seq_lens_cpu_cache is not None
- else self.seq_lens.cpu()
- )
- else:
- seq_lens_cpu = None
-
  if self.sampling_info:
  if self.has_grammar:
  self.sampling_info.grammars = [req.grammar for req in self.reqs]
  else:
  self.sampling_info.grammars = None

+ seq_lens_cpu = (
+ seq_lens_cpu_cache
+ if seq_lens_cpu_cache is not None
+ else self.seq_lens.cpu()
+ )
+
  global bid
  bid += 1
  return ModelWorkerBatch(
@@ -1800,18 +1797,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  return_logprob=self.return_logprob,
  decoding_reqs=self.decoding_reqs,
  spec_algorithm=self.spec_algorithm,
- enable_custom_logit_processor=self.enable_custom_logit_processor,
  global_num_tokens=self.global_num_tokens,
  global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
  can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
  is_extend_in_batch=self.is_extend_in_batch,
+ is_prefill_only=self.is_prefill_only,
  )

- def _evict_tree_cache_if_needed(
- self,
- num_tokens: int,
- ) -> None:
- if isinstance(self.tree_cache, SWAChunkCache):
+ def _evict_tree_cache_if_needed(self, num_tokens: int):
+ if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
  return

  if self.is_hybrid:
sglang/srt/managers/schedule_policy.py

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
  # This can prevent the server from being too conservative.
  # Note that this only clips the estimation in the scheduler but does not change the stop
  # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
- CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+ CLIP_MAX_NEW_TOKENS = int(
  os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
  )

@@ -305,7 +305,7 @@ class PrefillAdder:
  [
  min(
  (r.sampling_params.max_new_tokens - len(r.output_ids)),
- CLIP_MAX_NEW_TOKENS_ESTIMATION,
+ CLIP_MAX_NEW_TOKENS,
  )
  * self.new_token_ratio
  for r in running_batch.reqs
@@ -388,7 +388,7 @@ class PrefillAdder:
  0,
  req.extend_input_len,
  (
- min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
  if not truncated
  else 0
  ),
@@ -477,7 +477,7 @@ class PrefillAdder:
  self._update_prefill_budget(
  0,
  req.extend_input_len,
- min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
  )
  else:
  if self.rem_chunk_tokens == 0:
@@ -499,7 +499,7 @@ class PrefillAdder:
  return self.add_one_req_ignore_eos(req, has_chunked_req)

  total_tokens = req.extend_input_len + min(
- req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
+ req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
  )

  # adjusting the input_tokens based on host_hit_length and page_size
@@ -544,7 +544,7 @@ class PrefillAdder:
  input_tokens,
  min(
  req.sampling_params.max_new_tokens,
- CLIP_MAX_NEW_TOKENS_ESTIMATION,
+ CLIP_MAX_NEW_TOKENS,
  ),
  )
  else:
sglang/srt/managers/scheduler.py

@@ -130,6 +130,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
  from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
@@ -611,12 +612,7 @@ class Scheduler(
  hicache_ratio=server_args.hicache_ratio,
  hicache_size=server_args.hicache_size,
  hicache_write_policy=server_args.hicache_write_policy,
- hicache_io_backend=(
- "direct"
- if server_args.attention_backend
- == "fa3" # hot fix for incompatibility
- else server_args.hicache_io_backend
- ),
+ hicache_io_backend=server_args.hicache_io_backend,
  hicache_mem_layout=server_args.hicache_mem_layout,
  hicache_storage_backend=server_args.hicache_storage_backend,
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
@@ -635,7 +631,19 @@ class Scheduler(
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
  )
-
+ elif self.enable_lora:
+ assert (
+ not self.enable_hierarchical_cache
+ ), "LoRA radix cache doesn't support hierarchical cache"
+ assert (
+ self.schedule_policy == "fcfs"
+ ), "LoRA radix cache only supports FCFS policy"
+ self.tree_cache = LoRARadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ page_size=self.page_size,
+ disable=server_args.disable_radix_cache,
+ )
  else:
  self.tree_cache = RadixCache(
  req_to_token_pool=self.req_to_token_pool,
@@ -1458,8 +1466,9 @@ class Scheduler(
  if self.last_batch.batch_size() < last_bs:
  self.running_batch.batch_is_full = False

- # Merge the new batch into the running batch
- if not self.last_batch.is_empty():
+ # Merge the new batch into the running batch.
+ # For prefill-only batch, we can avoid going through decoding step.
+ if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
  if self.running_batch.is_empty():
  self.running_batch = self.last_batch
  else:
@@ -1626,7 +1635,6 @@ class Scheduler(
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
- self.server_args.enable_custom_logit_processor,
  chunked_req=self.chunked_req,
  )
  if self.enable_hierarchical_cache:
@@ -2023,7 +2031,6 @@ class Scheduler(
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
- self.server_args.enable_custom_logit_processor,
  )
  idle_batch.prepare_for_idle()
  return idle_batch
sglang/srt/managers/scheduler_profiler_mixin.py

@@ -8,6 +8,18 @@ import torch

  from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
+ from sglang.srt.utils import is_npu
+
+ _is_npu = is_npu()
+ if _is_npu:
+ import torch_npu
+
+ patches = [
+ ["profiler.profile", torch_npu.profiler.profile],
+ ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+ ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+ ]
+ torch_npu._apply_patches(patches)

  logger = logging.getLogger(__name__)

@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
  activities=torchprof_activities,
  with_stack=with_stack if with_stack is not None else True,
  record_shapes=record_shapes if record_shapes is not None else False,
+ on_trace_ready=(
+ None
+ if not _is_npu
+ else torch_npu.profiler.tensorboard_trace_handler(
+ self.torch_profiler_output_dir
+ )
+ ),
  )
  self.torch_profiler.start()
  self.profile_in_progress = True
@@ -166,15 +185,16 @@ class SchedulerProfilerMixin:
  logger.info("Stop profiling" + stage_suffix + "...")
  if self.torch_profiler is not None:
  self.torch_profiler.stop()
- self.torch_profiler.export_chrome_trace(
- os.path.join(
- self.torch_profiler_output_dir,
- self.profile_id
- + f"-TP-{self.tp_rank}"
- + stage_suffix
- + ".trace.json.gz",
+ if not _is_npu:
+ self.torch_profiler.export_chrome_trace(
+ os.path.join(
+ self.torch_profiler_output_dir,
+ self.profile_id
+ + f"-TP-{self.tp_rank}"
+ + stage_suffix
+ + ".trace.json.gz",
+ )
  )
- )
  torch.distributed.barrier(self.tp_cpu_group)

  if self.rpd_profiler is not None: