sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -169,12 +169,13 @@ class StorageOperation:
         host_indices: torch.Tensor,
         token_ids: List[int],
         last_hash: Optional[str] = None,
+        hash_value: Optional[List[str]] = None,
     ):
         self.host_indices = host_indices
         self.token_ids = token_ids
         self.last_hash = last_hash
         self.completed_tokens = 0
-        self.hash_value = []
+        self.hash_value = hash_value if hash_value is not None else []

         self.id = StorageOperation.counter
         StorageOperation.counter += 1
@@ -259,6 +260,7 @@ class HiCacheController:
             self.storage_backend = MooncakeStore()
             self.get_hash_str = get_hash_str_mooncake
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
             from sglang.srt.distributed import get_tensor_model_parallel_rank
             from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
@@ -433,7 +435,9 @@ class HiCacheController:
         if self.io_backend == "kernel":
             return host_indices.to(self.mem_pool_device.device), device_indices
         elif self.io_backend == "direct":
-            return host_indices, device_indices.cpu()
+            device_indices = device_indices.cpu()
+            host_indices, idx = host_indices.sort()
+            return host_indices, device_indices.index_select(0, idx)
         else:
             raise ValueError(f"Unsupported io backend")
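Note on the `direct` branch above: sorting `host_indices` and permuting `device_indices` with the same index keeps every host/device pair matched while making host-side accesses ordered. A minimal standalone sketch of that pairing (tensor values are illustrative only, not taken from the real pools):

import torch

host_indices = torch.tensor([7, 2, 9, 4])             # illustrative host slots
device_indices = torch.tensor([100, 101, 102, 103])   # matching device slots

sorted_host, perm = host_indices.sort()
paired_device = device_indices.index_select(0, perm)
# sorted_host   -> tensor([2, 4, 7, 9])
# paired_device -> tensor([101, 103, 100, 102]); host slot 7 still maps to device slot 100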
@@ -570,10 +574,6 @@ class HiCacheController:
                     )
                     completed_tokens += self.page_size
                 else:
-                    # operation terminated by controller, release pre-allocated memory
-                    self.mem_pool_host.free(
-                        operation.host_indices[operation.completed_tokens :]
-                    )
                     break

     def mooncake_page_transfer(self, operation):
@@ -599,6 +599,14 @@ class HiCacheController:
                     self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
+
+                if self.tp_world_size > 1:
+                    # to ensure all TP workers release the host memory at the same time
+                    torch.distributed.barrier(group=self.prefetch_tp_group)
+                # operation terminated by controller, release pre-allocated memory
+                self.mem_pool_host.free(
+                    operation.host_indices[operation.completed_tokens :]
+                )
             except Empty:
                 continue

@@ -626,7 +634,9 @@ class HiCacheController:
                 continue

             storage_hit_count = 0
-            if self.prefetch_rate_limit_check():
+            if (
+                operation.host_indices is not None
+            ) and self.prefetch_rate_limit_check():
                 last_hash = operation.last_hash
                 tokens_to_fetch = operation.token_ids

@@ -670,7 +680,8 @@ class HiCacheController:
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
-                self.mem_pool_host.free(operation.host_indices)
+                if operation.host_indices is not None:
+                    self.mem_pool_host.free(operation.host_indices)
                 logger.debug(
                     f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                 )
@@ -693,12 +704,12 @@ class HiCacheController:
         self,
         host_indices: torch.Tensor,
         token_ids: List[int],
-        last_hash: Optional[str] = None,
+        hash_value: Optional[List[str]] = None,
     ) -> int:
         """
         Write KV caches from host memory to storage backend.
         """
-        operation = StorageOperation(host_indices, token_ids, last_hash)
+        operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
         self.backup_queue.put(operation)
         return operation.id

@@ -753,24 +764,6 @@ class HiCacheController:
                 if operation is None:
                     continue

-                last_hash = operation.last_hash
-                tokens_to_backup = operation.token_ids
-
-                backup_hit_count = 0
-                remaining_tokens = len(tokens_to_backup)
-                hash_value = []
-                while remaining_tokens >= self.page_size:
-                    last_hash = self.get_hash_str(
-                        tokens_to_backup[
-                            backup_hit_count : backup_hit_count + self.page_size
-                        ],
-                        last_hash,
-                    )
-                    backup_hit_count += self.page_size
-                    hash_value.append(last_hash)
-                    remaining_tokens -= self.page_size
-                operation.hash_value = hash_value
-
                 if self.is_mooncake_backend():
                     self.mooncake_page_backup(operation)
                 elif self.storage_backend_type == "hf3fs":
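Note: with the per-page hashing loop removed from the backup thread, `write_backup` now expects the caller to pass a precomputed `hash_value` list. A hedged sketch of what that caller-side computation could look like, mirroring the deleted loop (`get_hash_str`, `page_size`, `token_ids`, and `controller` stand in for whatever the real call site provides):

hash_values, last_hash = [], None
num_full_pages = len(token_ids) // page_size
for i in range(num_full_pages):
    page = token_ids[i * page_size : (i + 1) * page_size]
    last_hash = get_hash_str(page, last_hash)   # chain each page hash onto the previous one
    hash_values.append(last_hash)
ack_id = controller.write_backup(host_indices, token_ids, hash_value=hash_values)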
@@ -793,7 +786,6 @@ class HiCacheController:
                 self.ack_backup_queue.put(
                     (
                         operation.id,
-                        operation.hash_value[: min_completed_tokens // self.page_size],
                         min_completed_tokens,
                     )
                 )
@@ -216,7 +216,7 @@ class DetokenizerManager:
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,
-            output_ids=recv_obj.decode_ids,
+            output_ids=recv_obj.output_ids,
             prompt_tokens=recv_obj.prompt_tokens,
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
@@ -99,25 +99,24 @@ class GenerateReqInput:
     stream: bool = False
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
     log_metrics: bool = True
+    # Whether to return hidden states
+    return_hidden_states: Union[List[bool], bool] = False

     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
+
     # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # The uid of LoRA adaptors, should be initialized by tokenizer manager
     lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

-    # Session info for continual prompting
-    session_params: Optional[Union[List[Dict], Dict]] = None
-
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

-    # Whether to return hidden states
-    return_hidden_states: Union[List[bool], bool] = False
-
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
@@ -456,6 +455,7 @@ class GenerateReqInput:
             log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            lora_id=self.lora_id[i] if self.lora_id is not None else None,
             custom_logit_processor=(
                 self.custom_logit_processor[i]
                 if self.custom_logit_processor is not None
@@ -614,8 +614,7 @@ def general_mm_embed_routine(
         input_ids: Input token IDs tensor
         forward_batch: Batch information for model forward pass
         language_model: Base language model to use
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
+        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
         **kwargs: Additional arguments passed to language model
@@ -20,7 +20,7 @@ def import_processors():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}: " f"{e}")
+            logger.warning(f"Ignore import error when loading {name}: {e}")
             continue
         all_members = inspect.getmembers(module, inspect.isclass)
         classes = [
@@ -37,6 +37,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
+from itertools import chain
 from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -57,6 +58,7 @@ from sglang.srt.mem_cache.allocator import (
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
@@ -638,14 +640,26 @@ class Req:
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-            (
-                self.prefix_indices,
-                self.last_node,
-                self.last_host_node,
-                self.host_hit_length,
-            ) = tree_cache.match_prefix(
-                key=self.adjust_max_prefix_ids(),
-            )
+            if isinstance(tree_cache, LoRARadixCache):
+                (
+                    self.prefix_indices,
+                    self.last_node,
+                    self.last_host_node,
+                    self.host_hit_length,
+                ) = tree_cache.match_prefix_with_lora_id(
+                    key=LoRAKey(
+                        lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
+                    ),
+                )
+            else:
+                (
+                    self.prefix_indices,
+                    self.last_node,
+                    self.last_host_node,
+                    self.host_hit_length,
+                ) = tree_cache.match_prefix(
+                    key=self.adjust_max_prefix_ids(),
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
@@ -1145,9 +1159,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        input_ids_tensor = torch.tensor(
+            list(chain.from_iterable(input_ids)), dtype=torch.int64
+        ).to(self.device, non_blocking=True)
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
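Note on the change above: `sum(input_ids, [])` copies the accumulated list on every addition, so flattening is quadratic in the total token count, while `chain.from_iterable` does a single linear pass. A small self-contained comparison:

from itertools import chain

input_ids = [[1, 2, 3], [4, 5], [6]]          # per-request token id lists
flat = list(chain.from_iterable(input_ids))   # [1, 2, 3, 4, 5, 6], one linear pass
# sum(input_ids, []) gives the same result but rebuilds the growing list at each step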
@@ -1713,15 +1727,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            attention_backend_str == "fa3"
-            or (
-                global_server_args_dict["use_mla_backend"]
-                and attention_backend_str == "flashinfer"
-            )
-            or attention_backend_str == "flashmla"
-            or attention_backend_str == "cutlass_mla"
-            or attention_backend_str == "ascend"
-            or attention_backend_str == "trtllm_mha"
+            attention_backend_str
+            in [
+                "fa3",
+                "flashinfer",
+                "flashmla",
+                "cutlass_mla",
+                "ascend",
+                "trtllm_mha",
+                "aiter",
+            ]
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+CLIP_MAX_NEW_TOKENS = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )

@@ -305,7 +305,7 @@ class PrefillAdder:
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                    CLIP_MAX_NEW_TOKENS,
                 )
                 * self.new_token_ratio
                 for r in running_batch.reqs
@@ -388,7 +388,7 @@ class PrefillAdder:
             0,
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                 if not truncated
                 else 0
             ),
@@ -477,7 +477,7 @@ class PrefillAdder:
             self._update_prefill_budget(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
             if self.rem_chunk_tokens == 0:
@@ -499,7 +499,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req, has_chunked_req)

         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )

         # adjusting the input_tokens based on host_hit_length and page_size
@@ -544,7 +544,7 @@ class PrefillAdder:
                 input_tokens,
                 min(
                     req.sampling_params.max_new_tokens,
-                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                    CLIP_MAX_NEW_TOKENS,
                 ),
             )
         else:
@@ -130,6 +130,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
@@ -611,12 +612,7 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
-                hicache_io_backend=(
-                    "direct"
-                    if server_args.attention_backend
-                    == "fa3"  # hot fix for incompatibility
-                    else server_args.hicache_io_backend
-                ),
+                hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
@@ -635,7 +631,19 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
-
+        elif self.enable_lora:
+            assert (
+                not self.enable_hierarchical_cache
+            ), "LoRA radix cache doesn't support hierarchical cache"
+            assert (
+                self.schedule_policy == "fcfs"
+            ), "LoRA radix cache only supports FCFS policy"
+            self.tree_cache = LoRARadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -8,6 +8,18 @@ import torch

 from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
+
+    patches = [
+        ["profiler.profile", torch_npu.profiler.profile],
+        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+    ]
+    torch_npu._apply_patches(patches)

 logger = logging.getLogger(__name__)

@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
                 activities=torchprof_activities,
                 with_stack=with_stack if with_stack is not None else True,
                 record_shapes=record_shapes if record_shapes is not None else False,
+                on_trace_ready=(
+                    None
+                    if not _is_npu
+                    else torch_npu.profiler.tensorboard_trace_handler(
+                        self.torch_profiler_output_dir
+                    )
+                ),
             )
             self.torch_profiler.start()
             self.profile_in_progress = True
@@ -166,15 +185,16 @@ class SchedulerProfilerMixin:
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
+            if not _is_npu:
+                self.torch_profiler.export_chrome_trace(
+                    os.path.join(
+                        self.torch_profiler_output_dir,
+                        self.profile_id
+                        + f"-TP-{self.tp_rank}"
+                        + stage_suffix
+                        + ".trace.json.gz",
+                    )
                 )
-            )
             torch.distributed.barrier(self.tp_cpu_group)

         if self.rpd_profiler is not None:
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()

         # Health check
-        self.health_check_failed = False
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting

         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -291,8 +290,8 @@ class TokenizerManager:
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self._is_updating = False
-        self._is_updating_cond = asyncio.Condition()
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()

         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,16 +475,20 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
+            if self.server_args.enable_lora and obj.lora_path:
+                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
@@ -553,11 +556,6 @@ class TokenizerManager:
         else:
             mm_inputs = None

-        if self.server_args.enable_lora and obj.lora_path:
-            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
-            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
-
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -775,10 +773,6 @@ class TokenizerManager:
             msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
             logger.info(msg)

-        # Mark ongoing LoRA request as finished.
-        if self.server_args.enable_lora and obj.lora_path:
-            await self.lora_registry.release(obj.lora_id)
-
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
             finish_reason = out["meta_info"]["finish_reason"]
@@ -797,6 +791,11 @@ class TokenizerManager:
                 # Delete the key to prevent resending abort request to the scheduler and
                 # to ensure aborted request state is cleaned up.
                 del self.rid_to_state[state.obj.rid]
+
+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and state.obj.lora_path:
+                    await self.lora_registry.release(state.obj.lora_id)
+
                 raise fastapi.HTTPException(
                     status_code=finish_reason["status_code"],
                     detail=finish_reason["message"],
@@ -982,14 +981,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = True
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = False
-            self._is_updating_cond.notify_all()
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()

     async def update_weights_from_disk(
         self,
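Note: the renamed `is_pause` / `is_pause_cond` pair acts as an asyncio gate: incoming requests wait on the condition while generation is paused, and `continue_generation` wakes all waiters. A minimal sketch of the same pattern outside the class (module-level names here are illustrative only):

import asyncio

is_pause = False
is_pause_cond = asyncio.Condition()

async def handle_request():
    async with is_pause_cond:
        # block until generation is resumed
        await is_pause_cond.wait_for(lambda: not is_pause)
    # ... proceed with tokenization and dispatch

async def pause_generation():
    global is_pause
    async with is_pause_cond:
        is_pause = True

async def continue_generation():
    global is_pause
    async with is_pause_cond:
        is_pause = False
        is_pause_cond.notify_all()  # wake every request waiting in wait_for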
@@ -1474,7 +1473,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)

-            if self.health_check_failed:
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1600,6 +1599,10 @@ class TokenizerManager:
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
                 del self.rid_to_state[rid]

+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and state.obj.lora_path:
+                    asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
+
             state.out_list.append(out_dict)
             state.event.set()
@@ -1965,10 +1968,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up


 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode: