sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -29,6 +29,7 @@ import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
+from enum import Enum
 from http import HTTPStatus
 from typing import (
     Any,
@@ -70,7 +71,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -116,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -202,13 +203,29 @@ class TokenizerManager:

         if self.model_config.is_multimodal:
             import_processors()
-            (removed lines 205-211; content not shown)
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
         transport_mode = _determine_tensor_transport_mode(self.server_args)

         # We want to parallelize the image pre-processing so we create an executor for it
@@ -225,10 +242,10 @@ class TokenizerManager:
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
-            self.mm_processor = None
+            self.mm_processor = self.processor = None

             if server_args.skip_tokenizer_init:
-                self.tokenizer =
+                self.tokenizer = None
             else:
                 self.tokenizer = get_tokenizer(
                     server_args.tokenizer_path,
@@ -255,6 +272,7 @@ class TokenizerManager:
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+        self.server_status = ServerStatus.Starting

         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -538,7 +556,7 @@ class TokenizerManager:
         if self.server_args.enable_lora and obj.lora_path:
             # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
             # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.
+            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
@@ -647,7 +665,7 @@ class TokenizerManager:
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-            (removed line; content not shown)
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -732,7 +750,11 @@ class TokenizerManager:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
-                if
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -755,7 +777,7 @@ class TokenizerManager:

             # Mark ongoing LoRA request as finished.
             if self.server_args.enable_lora and obj.lora_path:
-                await self.lora_registry.release(obj.
+                await self.lora_registry.release(obj.lora_id)

             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -787,7 +809,11 @@ class TokenizerManager:
             if obj.stream:
                 yield out
             else:
-                if
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -1069,38 +1095,57 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )

-        (removed lines; content not shown)
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )

-        (removed lines; content not shown)
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
             )

-        (removed lines; content not shown)
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )

-        (removed lines; content not shown)
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                    pinned=obj.pinned,
+                )

-        (removed line; content not shown)
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )

     async def unload_lora_adapter(
         self,
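
Two details in the load path above carry the design: the whole update is serialized behind `lora_update_lock`, and the adapter is registered in `lora_registry` only after the backend confirms a successful load, so concurrent requests can never acquire a half-loaded adapter. A minimal sketch of that ordering with stand-in objects (none of these names are sglang APIs; only the ordering mirrors the diff):

    import asyncio
    import uuid

    # Hypothetical stand-ins for illustration.
    registry: dict[str, str] = {}   # name -> id, holds successfully loaded adapters only
    update_lock = asyncio.Lock()

    async def backend_load(lora_id: str, lora_path: str) -> bool:
        await asyncio.sleep(0)      # pretend to fan the load out to worker processes
        return True                 # workers report success

    async def load_adapter(lora_name: str, lora_path: str) -> str | None:
        async with update_lock:     # dynamic adapter updates are serialized
            lora_id = uuid.uuid4().hex           # unique id chosen before loading
            if await backend_load(lora_id, lora_path):
                registry[lora_name] = lora_id    # register only after the backend succeeds
                return lora_id
            return None

    print(asyncio.run(load_adapter("my-adapter", "/tmp/adapter")))
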
@@ -1108,37 +1153,41 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )

-        (removed lines; content not shown)
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )

-        (removed lines; content not shown)
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )

-        (removed lines; content not shown)
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id

-        (removed lines; content not shown)
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]

-        (removed line; content not shown)
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1508,8 +1557,17 @@ class TokenizerManager:

         if isinstance(recv_obj, BatchStrOut):
             state.text += recv_obj.output_strs[i]
+            if state.obj.stream:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids[state.last_output_offset :]
+                state.last_output_offset = len(state.output_ids)
+            else:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids.copy()
+
             out_dict = {
                 "text": state.text,
+                "output_ids": output_token_ids,
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
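
For streaming requests the handler above returns only the token ids produced since the previous event, by slicing from `last_output_offset`; non-streaming requests get a copy of the full list. A small illustration of the same offset bookkeeping (plain Python, not sglang code):

    output_ids: list[int] = []
    last_output_offset = 0

    def on_chunk(new_ids: list[int], stream: bool) -> list[int]:
        global last_output_offset
        output_ids.extend(new_ids)
        if stream:
            # Return only the tokens added since the last chunk.
            delta = output_ids[last_output_offset:]
            last_output_offset = len(output_ids)
            return delta
        # Non-streaming: return a snapshot of everything so far.
        return output_ids.copy()

    print(on_chunk([1, 2, 3], stream=True))   # [1, 2, 3]
    print(on_chunk([4, 5], stream=True))      # [4, 5]
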
@@ -1767,6 +1825,8 @@ class TokenizerManager:
             asyncio.create_task(asyncio.to_thread(background_task))

     def _handle_abort_req(self, recv_obj):
+        if is_health_check_generate_req(recv_obj):
+            return
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
         if recv_obj.finished_reason:
@@ -1901,6 +1961,16 @@ class TokenizerManager:
         return scores


+class ServerStatus(Enum):
+    Up = "Up"
+    Starting = "Starting"
+    UnHealthy = "UnHealthy"
+    Crashed = "Crashed"
+
+    def is_healthy(self) -> bool:
+        return self == ServerStatus.Up
+
+
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
     is_cross_node = server_args.dist_init_addr

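
`ServerStatus` gives the tokenizer manager an explicit lifecycle state: it starts as `Starting` (see the `__init__` change above) and only `Up` counts as healthy. A short usage sketch, assuming this sglang version is installed:

    from sglang.srt.managers.tokenizer_manager import ServerStatus  # added in 0.5.0rc0

    status = ServerStatus.Starting
    print(status.is_healthy())            # False: still starting up
    print(ServerStatus.Up.is_healthy())   # True: only Up is considered healthy
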

sglang/srt/managers/tp_worker.py CHANGED
@@ -311,3 +311,6 @@ class TpModelWorker:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)

sglang/srt/managers/tp_worker_overlap_thread.py CHANGED
@@ -288,6 +288,9 @@ class TpModelWorkerClient:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         return self.worker.unload_lora_adapter(recv_req)

+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.worker.can_run_lora_batch(lora_ids)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))

sglang/srt/managers/utils.py CHANGED
@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

@@ -38,3 +39,46 @@ def validate_input_length(
         return error_msg

     return None
+
+
+class DPBalanceMeta:
+    """
+    This class will be use in scheduler and dp controller
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # we must destructor this class manually
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
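
`DPBalanceMeta` keeps per-worker token counts and in-flight request info in a `multiprocessing.Manager`, so the data-parallel controller and the schedulers running in separate processes can read and update a shared view under a lock. A rough usage sketch, assuming this sglang version is installed (everything other than DPBalanceMeta's own methods is hypothetical):

    # Requires sglang >= 0.5.0rc0, which ships the DPBalanceMeta shown above.
    from sglang.srt.managers.utils import DPBalanceMeta

    if __name__ == "__main__":
        meta = DPBalanceMeta(num_workers=2)

        # Controller side: pick the least-loaded worker and account for a new request.
        with meta.mutex:
            tokens = meta.get_shared_local_tokens()          # e.g. [0, 0]
            target = min(range(len(tokens)), key=tokens.__getitem__)
            tokens[target] += 128                            # hypothetical 128-token request
            meta.set_shared_local_tokens(tokens)

        print(meta.get_shared_local_tokens())                # [128, 0]
        meta.destructor()  # the Manager process must be shut down explicitly
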

sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py ADDED
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import torch
+from torch.utils.cpp_extension import load
+
+_abs_path = os.path.dirname(os.path.abspath(__file__))
+radix_tree_cpp = load(
+    name="radix_tree_cpp",
+    sources=[
+        f"{_abs_path}/tree_v2_binding.cpp",
+        f"{_abs_path}/tree_v2_debug.cpp",
+        f"{_abs_path}/tree_v2.cpp",
+    ],
+    extra_cflags=["-O3", "-std=c++20"],
+)
+
+if TYPE_CHECKING:
+
+    class TreeNodeCpp:
+        """
+        A placeholder for the TreeNode class. Cannot be constructed elsewhere.
+        """
+
+    class IOHandle:
+        """
+        A placeholder for the IOHandle class. Cannot be constructed elsewhere.
+        """
+
+    class RadixTreeCpp:
+        def __init__(
+            self,
+            disabled: bool,
+            host_size: Optional[int],
+            page_size: int,
+            write_through_threshold: int,
+        ):
+            """
+            Initializes the RadixTreeCpp instance.
+            Args:
+                disabled (bool): If True, the radix tree is disabled.
+                host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
+                page_size (int): Size of the page for the radix tree.
+                write_through_threshold (int): Threshold for writing through from GPU to CPU.
+            """
+            self.tree = radix_tree_cpp.RadixTree(  # type: ignore
+                disabled, host_size, page_size, write_through_threshold
+            )
+
+        def match_prefix(
+            self, prefix: List[int]
+        ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
+            """
+            Matches a prefix in the radix tree.
+            Args:
+                prefix (List[int]): The prefix to match.
+            Returns:
+                Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]:
+                    0. A list of indices that is matched by the prefix on the GPU.
+                    1. Sum length of the indices matched on the CPU.
+                    2. The last node of the prefix matched on the GPU.
+                    3. The last node of the prefix matched on the CPU.
+            """
+            return self.tree.match_prefix(prefix)
+
+        def evict(self, num_tokens: int) -> List[torch.Tensor]:
+            """
+            Evicts a number of tokens from the radix tree.
+            Args:
+                num_tokens (int): The number of tokens to evict.
+            Returns:
+                List[torch.Tensor]: A list of indices that were evicted.
+            """
+            return self.tree.evict(num_tokens)
+
+        def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
+            """
+            Locks or unlocks a reference to a tree node.
+            After locking, the node will not be evicted from the radix tree.
+            Args:
+                handle (TreeNodeCpp): The tree node to lock or unlock.
+                lock (bool): If True, locks the node; if False, unlocks it.
+            """
+            return self.tree.lock_ref(handle, lock)
+
+        def writing_through(
+            self, key: List[int], indices: torch.Tensor
+        ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+            """
+            Inserts a key-value pair into the radix tree and perform write-through check.
+            Args:
+                key (List[int]): The key to insert.
+                indices (torch.Tensor): The value associated with the key.
+            Returns:
+                Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+                    0. A list of (IOHandle, device indices, host indices) tuples.
+                       These IOhandles require write-through to the CPU in python side.
+                    1. The number of indices that are matched on device.
+            """
+            return self.tree.writing_through(key, indices)
+
+        def loading_onboard(
+            self,
+            host_node: TreeNodeCpp,
+            new_device_indices: torch.Tensor,
+        ) -> Tuple[IOHandle, List[torch.Tensor]]:
+            """
+            Updates the device indices of tree nodes within a range on the tree.
+            Args:
+                host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node.
+                new_device_indices (torch.Tensor): The new device indices to set.
+                    The length of this tensor must be exactly host indices length.
+            Returns:
+                Tuple[IOHandle, List[torch.Tensor]]:
+                    0. An IOHandle that requires loading to the CPU in python side.
+                    1. A list of host indices corresponding to the new device indices.
+            """
+            return self.tree.loading_onboard(host_node, new_device_indices)
+
+        def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the write-through process for a tree node.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the write-through; if False, just indicates failure.
+            """
+            return self.tree.commit_writing_through(handle, success)
+
+        def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the load onboard process for tree nodes within a range on the tree.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the load-onboard; if False, just indicates failure.
+            """
+            return self.tree.commit_loading_onboard(handle, success)
+
+        def evictable_size(self) -> int:
+            """
+            Returns the size of the evictable part of the radix tree.
+            This is the size of the part that can be evicted from the GPU (ref_count = 0).
+            Returns:
+                int: The size of the evictable part.
+            """
+            return self.tree.evictable_size()
+
+        def protected_size(self) -> int:
+            """
+            Returns the size of the protected part of the radix tree.
+            This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
+            Returns:
+                int: The size of the protected part.
+            """
+            return self.tree.protected_size()
+
+        def total_size(self) -> int:
+            """
+            Returns the total size of the radix tree (including CPU nodes).
+            Returns:
+                int: The total size of the radix tree.
+            """
+            return self.tree.total_size()
+
+        def reset(self) -> None:
+            """
+            Resets the radix tree, clearing all nodes and indices.
+            """
+            return self.tree.reset()
+
+        def debug_print(self) -> None:
+            """
+            Prints the internal state of the radix tree for debugging purposes.
+            """
+            return self.tree.debug_print()
+
+else:
+    # Real implementation of the classes for runtime
+    RadixTreeCpp = radix_tree_cpp.RadixTree
+    TreeNodeCpp = object
+    IOHandle = object
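
The new wrapper JIT-compiles the C++ extension on first import and defines rich Python stubs only under `TYPE_CHECKING`; at runtime the names are bound directly to the pybind11 classes. A self-contained toy illustration of that stub-vs-runtime pattern (unrelated to the actual C++ extension):

    from typing import TYPE_CHECKING

    class _FastCounter:          # stand-in for a compiled pybind11 class
        def __init__(self):
            self._n = 0

        def add(self, k):
            self._n += k
            return self._n

    if TYPE_CHECKING:
        class Counter:
            """Typed stub seen only by type checkers and IDEs."""
            def add(self, k: int) -> int: ...
    else:
        # At runtime, expose the "compiled" implementation directly.
        Counter = _FastCounter

    c = Counter()
    print(c.add(3))  # 3

The benefit of the split is that type checkers get full signatures and docstrings while runtime code pays no wrapper overhead per call.
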

sglang/srt/mem_cache/hicache_storage.py CHANGED
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """

-    # todo,
-    # potentially pass model and TP configs into storage backend
+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool

     @abstractmethod
@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
     def get(
         self,
         key: str,
-        target_location:
+        target_location: torch.Tensor,
         target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            (removed lines; content not shown)
-            target_location.
-            (removed lines; content not shown)
-                return target_location
-            else:
-                loaded_tensor = torch.load(tensor_path)
-                if isinstance(loaded_tensor, torch.Tensor):
-                    return loaded_tensor
-                else:
-                    logger.error(f"Loaded data for key {key} is not a tensor.")
-                    return None
+            # Load directly into target_location's memory buffer
+            with open(tensor_path, "rb") as f:
+                target_location.set_(
+                    torch.frombuffer(f.read(), dtype=target_location.dtype)
+                    .reshape(target_location.shape)
+                    .untyped_storage()
+                )
+            return target_location
         except FileNotFoundError:
+            logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
             return None

     def batch_get(
         self,
         keys: List[str],
-        target_locations:
+        target_locations: List[torch.Tensor],
         target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
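
The reworked `get()` no longer materializes a second tensor with `torch.load`: it reads the raw bytes and swaps them into `target_location`'s storage via `set_()`, so the caller's pre-allocated buffer ends up holding the data. A small standalone round-trip in the same spirit (the path and shapes are illustrative, not sglang code):

    import os
    import tempfile
    import torch

    value = torch.arange(8, dtype=torch.float16).reshape(2, 4)
    path = os.path.join(tempfile.mkdtemp(), "demo.bin")

    # Save as raw bytes, mirroring HiCacheFile.set()
    value.contiguous().view(dtype=torch.uint8).numpy().tofile(path)

    # Load into a pre-allocated buffer, mirroring HiCacheFile.get()
    target = torch.empty_like(value)
    with open(path, "rb") as f:
        target.set_(
            torch.frombuffer(f.read(), dtype=target.dtype)
            .reshape(target.shape)
            .untyped_storage()
        )

    print(torch.equal(target.reshape(2, 4), value))  # True
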
@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
         try:
-            torch.
+            value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
         except Exception as e:
             logger.error(f"Failed to save tensor {key}: {e}")