sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the sglang package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/tokenizer_manager.py
+++ b/sglang/srt/managers/tokenizer_manager.py
@@ -29,6 +29,7 @@ import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
+from enum import Enum
 from http import HTTPStatus
 from typing import (
     Any,
@@ -70,7 +71,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -116,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -202,13 +203,29 @@ class TokenizerManager:
 
         if self.model_config.is_multimodal:
             import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
             transport_mode = _determine_tensor_transport_mode(self.server_args)
 
             # We want to parallelize the image pre-processing so we create an executor for it
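Note: the hunk above makes multimodal processor construction tolerant of processors that ship only a fast implementation. A minimal standalone sketch of the same retry pattern, assuming a Hugging Face AutoProcessor (the model path and logger setup are illustrative, not sglang code):

import logging

from transformers import AutoProcessor

logger = logging.getLogger(__name__)

def load_processor(path: str, prefer_fast: bool):
    # Try the caller's preference first; fall back to the fast processor when
    # no slow implementation exists, instead of failing server startup.
    try:
        return AutoProcessor.from_pretrained(path, use_fast=prefer_fast)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            logger.info("No slow processor for %s; using fast version", path)
            return AutoProcessor.from_pretrained(path, use_fast=True)
        raise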
@@ -225,10 +242,10 @@ class TokenizerManager:
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
-            self.mm_processor = None
+            self.mm_processor = self.processor = None
 
             if server_args.skip_tokenizer_init:
-                self.tokenizer = self.processor = None
+                self.tokenizer = None
             else:
                 self.tokenizer = get_tokenizer(
                     server_args.tokenizer_path,
@@ -255,6 +272,7 @@ class TokenizerManager:
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+        self.server_status = ServerStatus.Starting
 
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -538,7 +556,7 @@ class TokenizerManager:
         if self.server_args.enable_lora and obj.lora_path:
             # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
             # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
 
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
@@ -647,7 +665,7 @@ class TokenizerManager:
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-            lora_path=obj.lora_path,
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -732,7 +750,11 @@ class TokenizerManager:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -755,7 +777,7 @@ class TokenizerManager:
 
             # Mark ongoing LoRA request as finished.
             if self.server_args.enable_lora and obj.lora_path:
-                await self.lora_registry.release(obj.lora_path)
+                await self.lora_registry.release(obj.lora_id)
 
             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
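The `lora_path` → `lora_id` hunks above change the request lifecycle: `acquire` resolves a user-facing adapter name to a registry-issued unique ID (and starts tracking the request), and the matching `release` must later be called with that same ID. A toy sketch of that contract, not the actual sglang LoRARegistry API:

import asyncio

class MiniLoRARegistry:
    """Illustrative name -> id registry with per-adapter in-flight counters."""

    def __init__(self) -> None:
        self._ids: dict[str, str] = {}       # user-facing name -> unique id
        self._inflight: dict[str, int] = {}  # unique id -> ongoing request count
        self._lock = asyncio.Lock()

    async def acquire(self, name: str) -> str:
        async with self._lock:
            lora_id = self._ids[name]
            self._inflight[lora_id] += 1
            return lora_id  # callers keep the id, not the name

    async def release(self, lora_id: str) -> None:
        async with self._lock:
            self._inflight[lora_id] -= 1  # unload can proceed once this hits 0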
@@ -787,7 +809,11 @@ class TokenizerManager:
             if obj.stream:
                 yield out
             else:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
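Both disconnect checks now also consult `obj.background`, so requests detached from their originating HTTP connection (background responses) are no longer aborted when that connection drops. The surrounding wait loop is a common asyncio pattern, sketched here with a hypothetical `is_disconnected` coroutine:

import asyncio
from typing import Awaitable, Callable

async def wait_or_abort(
    event: asyncio.Event,
    is_disconnected: Callable[[], Awaitable[bool]],
    background: bool,
) -> None:
    """Wait for the result event, polling for client disconnects every 4s."""
    while not event.is_set():
        try:
            await asyncio.wait_for(event.wait(), timeout=4)
        except asyncio.TimeoutError:
            # Background requests outlive the connection that created them.
            if not background and await is_disconnected():
                raise ValueError("request aborted: client disconnected")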
@@ -1069,38 +1095,57 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-        async with self.lora_update_lock:
-            # Generate new uniquely identifiable LoRARef object.
-            new_adapter = LoRARef(
-                lora_name=obj.lora_name,
-                lora_path=obj.lora_path,
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
             )
 
-            # Trigger the actual loading operation at the backend processes.
-            obj.lora_id = new_adapter.lora_id
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )
 
-            # Register the LoRA adapter only after loading is successful.
-            if result.success:
-                await self.lora_registry.register(new_adapter)
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                    pinned=obj.pinned,
+                )
 
-        return result
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )
 
     async def unload_lora_adapter(
         self,
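With the try/except wrapper, configuration problems (LoRA disabled, `max_loaded_loras` exceeded) now surface as a structured failure response instead of an unhandled exception. A hedged client-side sketch; the endpoint name matches sglang's dynamic-LoRA feature, but treat the exact payload, port, and adapter path as assumptions:

import requests

resp = requests.post(
    "http://localhost:30000/load_lora_adapter",  # hypothetical local server
    json={"lora_name": "my-adapter", "lora_path": "/models/my-adapter"},
)
result = resp.json()
if not result.get("success", False):
    # e.g. "Maximum number of loaded LoRA adapters is N. Please unload ..."
    print("load failed:", result.get("error_message"))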
@@ -1108,37 +1153,41 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-        assert (
-            obj.lora_name is not None
-        ), "lora_name must be provided to unload LoRA adapter"
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start unload Lora adapter. Lora name=%s",
-            obj.lora_name,
-        )
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )
 
-        async with self.lora_update_lock:
-            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-            # from being started.
-            lora_id = await self.lora_registry.unregister(obj.lora_name)
-            obj.lora_id = lora_id
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id
 
-            # Initiate the actual unloading operation at the backend processes only after all
-            # ongoing requests using this LoRA adapter are finished.
-            await self.lora_registry.wait_for_unload(lora_id)
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]
 
-        return result
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1508,8 +1557,17 @@ class TokenizerManager:
 
         if isinstance(recv_obj, BatchStrOut):
             state.text += recv_obj.output_strs[i]
+            if state.obj.stream:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids[state.last_output_offset :]
+                state.last_output_offset = len(state.output_ids)
+            else:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids.copy()
+
             out_dict = {
                 "text": state.text,
+                "output_ids": output_token_ids,
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
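The new `output_ids` plumbing keeps one cumulative token list per request; in streaming mode only the tokens appended since the last flush are emitted, tracked by `last_output_offset`. The slicing logic in isolation (illustrative names, same technique):

from typing import List

class StreamState:
    def __init__(self) -> None:
        self.output_ids: List[int] = []   # cumulative decoded token ids
        self.last_output_offset = 0       # ids already emitted to the client

def next_output_ids(state: StreamState, new_ids: List[int], stream: bool) -> List[int]:
    state.output_ids.extend(new_ids)
    if stream:
        chunk = state.output_ids[state.last_output_offset:]  # delta only
        state.last_output_offset = len(state.output_ids)
        return chunk
    # Non-streaming callers get the full list, copied so later extends
    # cannot mutate what was already returned.
    return state.output_ids.copy()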
@@ -1767,6 +1825,8 @@ class TokenizerManager:
         asyncio.create_task(asyncio.to_thread(background_task))
 
     def _handle_abort_req(self, recv_obj):
+        if is_health_check_generate_req(recv_obj):
+            return
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
         if recv_obj.finished_reason:
@@ -1901,6 +1961,16 @@ class TokenizerManager:
         return scores
 
 
+class ServerStatus(Enum):
+    Up = "Up"
+    Starting = "Starting"
+    UnHealthy = "UnHealthy"
+    Crashed = "Crashed"
+
+    def is_healthy(self) -> bool:
+        return self == ServerStatus.Up
+
+
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
     is_cross_node = server_args.dist_init_addr
 
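`ServerStatus` gives the tokenizer manager an explicit lifecycle; it is initialized to `Starting` in the constructor hunk above and presumably flipped to `Up` once warmup completes. A sketch of how a health endpoint might consume it (the handler is hypothetical; the enum mirrors the one added above):

from enum import Enum
from http import HTTPStatus

class ServerStatus(Enum):  # mirrors the class added in the hunk above
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
    Crashed = "Crashed"

    def is_healthy(self) -> bool:
        return self == ServerStatus.Up

def health_response(status: ServerStatus) -> tuple[int, str]:
    # Only a server that has finished starting up reports healthy.
    if status.is_healthy():
        return HTTPStatus.OK.value, "ok"
    return HTTPStatus.SERVICE_UNAVAILABLE.value, status.value

assert health_response(ServerStatus.Starting)[0] == 503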
--- a/sglang/srt/managers/tp_worker.py
+++ b/sglang/srt/managers/tp_worker.py
@@ -311,3 +311,6 @@ class TpModelWorker:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
--- a/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -288,6 +288,9 @@ class TpModelWorkerClient:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         return self.worker.unload_lora_adapter(recv_req)
 
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.worker.can_run_lora_batch(lora_ids)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
--- a/sglang/srt/managers/utils.py
+++ b/sglang/srt/managers/utils.py
@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional
 
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
@@ -38,3 +39,46 @@ def validate_input_length(
         return error_msg
 
     return None
+
+
+class DPBalanceMeta:
+    """
+    This class will be used in the scheduler and dp controller.
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # we must destroy this class manually
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
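`DPBalanceMeta` shares per-worker token counts and in-flight request maps between the data-parallel controller and scheduler processes through a `multiprocessing.Manager`; `__getstate__`/`__setstate__` drop the unpicklable manager handle so the object itself can be sent to child processes while the proxies stay live. A usage sketch under those assumptions (values are illustrative):

# Illustrative only; mirrors the accessors defined above.
meta = DPBalanceMeta(num_workers=2)
try:
    tokens = meta.get_shared_local_tokens()   # [0, 0]
    tokens[0] += 128                          # worker 0 admits a 128-token request
    meta.set_shared_local_tokens(tokens)

    onfly = meta.get_shared_onfly()           # one {req_id: token_count} per worker
    onfly[0][42] = 128                        # request 42 now in flight on worker 0
    meta.set_shared_onfly_info(onfly)
finally:
    meta.destructor()  # the Manager process must be shut down explicitly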
--- /dev/null
+++ b/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import torch
+from torch.utils.cpp_extension import load
+
+_abs_path = os.path.dirname(os.path.abspath(__file__))
+radix_tree_cpp = load(
+    name="radix_tree_cpp",
+    sources=[
+        f"{_abs_path}/tree_v2_binding.cpp",
+        f"{_abs_path}/tree_v2_debug.cpp",
+        f"{_abs_path}/tree_v2.cpp",
+    ],
+    extra_cflags=["-O3", "-std=c++20"],
+)
+
+if TYPE_CHECKING:
+
+    class TreeNodeCpp:
+        """
+        A placeholder for the TreeNode class. Cannot be constructed elsewhere.
+        """
+
+    class IOHandle:
+        """
+        A placeholder for the IOHandle class. Cannot be constructed elsewhere.
+        """
+
+    class RadixTreeCpp:
+        def __init__(
+            self,
+            disabled: bool,
+            host_size: Optional[int],
+            page_size: int,
+            write_through_threshold: int,
+        ):
+            """
+            Initializes the RadixTreeCpp instance.
+            Args:
+                disabled (bool): If True, the radix tree is disabled.
+                host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
+                page_size (int): Size of the page for the radix tree.
+                write_through_threshold (int): Threshold for writing through from GPU to CPU.
+            """
+            self.tree = radix_tree_cpp.RadixTree(  # type: ignore
+                disabled, host_size, page_size, write_through_threshold
+            )
+
+        def match_prefix(
+            self, prefix: List[int]
+        ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
+            """
+            Matches a prefix in the radix tree.
+            Args:
+                prefix (List[int]): The prefix to match.
+            Returns:
+                Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
+                    0. A list of indices that are matched by the prefix on the GPU.
+                    1. Sum length of the indices matched on the CPU.
+                    2. The last node of the prefix matched on the GPU.
+                    3. The last node of the prefix matched on the CPU.
+            """
+            return self.tree.match_prefix(prefix)
+
+        def evict(self, num_tokens: int) -> List[torch.Tensor]:
+            """
+            Evicts a number of tokens from the radix tree.
+            Args:
+                num_tokens (int): The number of tokens to evict.
+            Returns:
+                List[torch.Tensor]: A list of indices that were evicted.
+            """
+            return self.tree.evict(num_tokens)
+
+        def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
+            """
+            Locks or unlocks a reference to a tree node.
+            After locking, the node will not be evicted from the radix tree.
+            Args:
+                handle (TreeNodeCpp): The tree node to lock or unlock.
+                lock (bool): If True, locks the node; if False, unlocks it.
+            """
+            return self.tree.lock_ref(handle, lock)
+
+        def writing_through(
+            self, key: List[int], indices: torch.Tensor
+        ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+            """
+            Inserts a key-value pair into the radix tree and performs a write-through check.
+            Args:
+                key (List[int]): The key to insert.
+                indices (torch.Tensor): The value associated with the key.
+            Returns:
+                Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+                    0. A list of (IOHandle, device indices, host indices) tuples.
+                       These IOHandles require write-through to the CPU on the Python side.
+                    1. The number of indices that are matched on device.
+            """
+            return self.tree.writing_through(key, indices)
+
+        def loading_onboard(
+            self,
+            host_node: TreeNodeCpp,
+            new_device_indices: torch.Tensor,
+        ) -> Tuple[IOHandle, List[torch.Tensor]]:
+            """
+            Updates the device indices of tree nodes within a range on the tree.
+            Args:
+                host_node (TreeNodeCpp): The tree node on the host; must be a descendant of the device node.
+                new_device_indices (torch.Tensor): The new device indices to set.
+                    The length of this tensor must be exactly the host indices length.
+            Returns:
+                Tuple[IOHandle, List[torch.Tensor]]:
+                    0. An IOHandle that requires loading to the CPU on the Python side.
+                    1. A list of host indices corresponding to the new device indices.
+            """
+            return self.tree.loading_onboard(host_node, new_device_indices)
+
+        def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the write-through process for a tree node.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the write-through; if False, just indicates failure.
+            """
+            return self.tree.commit_writing_through(handle, success)
+
+        def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the load-onboard process for tree nodes within a range on the tree.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the load-onboard; if False, just indicates failure.
+            """
+            return self.tree.commit_loading_onboard(handle, success)
+
+        def evictable_size(self) -> int:
+            """
+            Returns the size of the evictable part of the radix tree.
+            This is the size of the part that can be evicted from the GPU (ref_count = 0).
+            Returns:
+                int: The size of the evictable part.
+            """
+            return self.tree.evictable_size()
+
+        def protected_size(self) -> int:
+            """
+            Returns the size of the protected part of the radix tree.
+            This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
+            Returns:
+                int: The size of the protected part.
+            """
+            return self.tree.protected_size()
+
+        def total_size(self) -> int:
+            """
+            Returns the total size of the radix tree (including CPU nodes).
+            Returns:
+                int: The total size of the radix tree.
+            """
+            return self.tree.total_size()
+
+        def reset(self) -> None:
+            """
+            Resets the radix tree, clearing all nodes and indices.
+            """
+            return self.tree.reset()
+
+        def debug_print(self) -> None:
+            """
+            Prints the internal state of the radix tree for debugging purposes.
+            """
+            return self.tree.debug_print()
+
+else:
+    # Real implementation of the classes for runtime
+    RadixTreeCpp = radix_tree_cpp.RadixTree
+    TreeNodeCpp = object
+    IOHandle = object
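The new module compiles the C++ tree on first import via `torch.utils.cpp_extension.load`, then uses an `if TYPE_CHECKING:` facade so type checkers see documented Python signatures while runtime attribute access goes straight to the extension class. The same pattern in miniature, with a pure-Python stand-in for the compiled class:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Typed facade: only type checkers ever read this body.
    class Counter:
        def __init__(self, start: int) -> None: ...
        def increment(self) -> int: ...
else:
    # At runtime the name binds directly to the real implementation with no
    # wrapper overhead; in radix_tree.py this is radix_tree_cpp.RadixTree.
    class Counter:
        def __init__(self, start):
            self.value = start

        def increment(self):
            self.value += 1
            return self.value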
--- a/sglang/srt/mem_cache/hicache_storage.py
+++ b/sglang/srt/mem_cache/hicache_storage.py
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """
 
-    # todo, translate tensor object access for different TP ranks
-    # potentially pass model and TP configs into storage backend
+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the host memory pool
 
     @abstractmethod
@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
     def get(
         self,
         key: str,
-        target_location: Optional[Any] = None,
+        target_location: torch.Tensor,
         target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            if target_location is not None:
-                # Load directly into target_location's memory buffer
-                with open(tensor_path, "rb") as f:
-                    target_location.set_(
-                        torch.frombuffer(f.read(), dtype=target_location.dtype)
-                        .reshape(target_location.shape)
-                        .storage()
-                    )
-                return target_location
-            else:
-                loaded_tensor = torch.load(tensor_path)
-                if isinstance(loaded_tensor, torch.Tensor):
-                    return loaded_tensor
-                else:
-                    logger.error(f"Loaded data for key {key} is not a tensor.")
-                    return None
+            # Load directly into target_location's memory buffer
+            with open(tensor_path, "rb") as f:
+                target_location.set_(
+                    torch.frombuffer(f.read(), dtype=target_location.dtype)
+                    .reshape(target_location.shape)
+                    .untyped_storage()
+                )
+            return target_location
         except FileNotFoundError:
+            logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
             return None
 
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[Any] = None,
+        target_locations: List[torch.Tensor],
         target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
         try:
-            torch.save(value, tensor_path)
+            value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
         except Exception as e:
             logger.error(f"Failed to save tensor {key}: {e}")
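`set` and `get` now move raw bytes instead of pickled tensors: `tofile` dumps the tensor's buffer (reinterpreted as `uint8` so `.numpy()` works for any dtype), and `get` grafts the bytes back onto the caller's preallocated tensor via `set_(...untyped_storage())`, avoiding an extra copy. A standalone round-trip sketch under those assumptions (1-D tensor; the path is arbitrary):

import torch

src = torch.arange(8, dtype=torch.float16)
path = "/tmp/hicache_demo.bin"

# Write: dump the raw little-endian buffer, dtype-agnostic.
src.contiguous().view(dtype=torch.uint8).numpy().tofile(path)

# Read: rebuild a tensor over the bytes, then swap its storage into a
# preallocated destination of the expected dtype and shape.
dst = torch.empty_like(src)
with open(path, "rb") as f:
    dst.set_(
        torch.frombuffer(bytearray(f.read()), dtype=dst.dtype)
        .reshape(dst.shape)
        .untyped_storage()
    )
assert torch.equal(src, dst)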