sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -67,6 +67,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -80,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -253,7 +259,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -510,7 +515,10 @@
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
+                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
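For context, the `(type, handler)` pairs above feed a type-to-handler dispatch table; the dispatcher class itself is defined elsewhere in sglang. A minimal, standalone sketch of the pattern (illustrative only, not sglang's exact implementation):

# Illustrative stand-in for the type-based dispatcher these (type, handler)
# pairs are registered with; the real class in sglang may differ in detail.
from typing import Any, Callable, List, Tuple, Type


class TypeBasedDispatcher:
    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        # Linear scan: the first entry whose type matches handles the request.
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj!r}")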
@@ -533,6 +541,7 @@
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )

@@ -623,6 +632,8 @@
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+                model_name=server_args.served_model_name,
+                storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -1018,14 +1029,26 @@
                 req
                 for req in recv_reqs
                 if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
             control_reqs = [
                 req
                 for req in recv_reqs
                 if not isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
         else:
@@ -1080,6 +1103,17 @@
                 )
                 self.send_to_tokenizer.send_pyobj(abort_req)
                 continue
+
+            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWarpper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWarpper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
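The block above is the scheduler half of the multi-tokenizer envelope protocol; the TokenizerManager half (further down in this diff) wraps outbound objects the same way. A rough round-trip sketch, assuming `MultiTokenizerWarpper` is a plain `(worker_id, obj)` envelope:

# Rough sketch of the envelope round trip; field names follow the diff, and
# the identifier keeps sglang's own spelling ("Warpper").
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class MultiTokenizerWarpper:
    worker_id: int
    obj: Any


def handle(recv_req: Any, dispatch: Callable[[Any], Optional[Any]]) -> Optional[Any]:
    # Unwrap, dispatch the inner request, then re-wrap any reply so it can be
    # routed back to the tokenizer worker that sent it.
    if isinstance(recv_req, MultiTokenizerWarpper):
        worker_id = recv_req.worker_id
        output = dispatch(recv_req.obj)
        return MultiTokenizerWarpper(worker_id, output) if output is not None else None
    return dispatch(recv_req)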
@@ -1253,6 +1287,17 @@
         else:
             self._add_request_to_queue(req)

+    def handle_batch_generate_request(
+        self,
+        recv_req: BatchTokenizedGenerateReqInput,
+    ):
+        """Handle optimized batch generate request."""
+        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_generate_request(tokenized_req)
+
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
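`len(recv_req)` and the `for tokenized_req in recv_req` loop above imply that the batch types are iterable containers over their `batch` field (they are constructed as `BatchTokenizedGenerateReqInput(batch=tokenized_objs)` in the tokenizer_manager diff below). A plausible shape, not the actual `io_struct` definition:

# Plausible shape of the batch request container; the real dataclass lives in
# sglang/srt/managers/io_struct.py and may carry more fields.
from dataclasses import dataclass, field
from typing import Any, Iterator, List


@dataclass
class BatchTokenizedGenerateReqInput:
    batch: List[Any] = field(default_factory=list)  # TokenizedGenerateReqInput items

    def __len__(self) -> int:
        return len(self.batch)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.batch)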
@@ -1269,10 +1314,11 @@
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
-            … (4 removed lines not shown in the source diff)
+            if req.last_node.backuped:
+                # only to initiate the prefetch if the last node is backuped
+                # otherwise, the allocated GPU memory must be locked for integrity
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
                 new_input_tokens = req.fill_ids[matched_len:]
                 self.tree_cache.prefetch_from_storage(
                     req.rid, req.last_host_node, new_input_tokens, last_hash
@@ -1335,6 +1381,19 @@
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)

+    def handle_batch_embedding_request(
+        self,
+        recv_req: BatchTokenizedEmbeddingReqInput,
+    ):
+        """Handle optimized batch embedding request."""
+        logger.debug(
+            f"Processing batch embedding request with {len(recv_req)} requests"
+        )
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_embedding_request(tokenized_req)
+
     def self_check_during_idle(self):
         self.check_memory()
         self.check_tree_cache()
@@ -1460,7 +1519,7 @@
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             chunked_req_to_exclude.add(self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
@@ -2164,6 +2223,16 @@
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)

+    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+        if self.enable_hierarchical_cache:
+            self.tree_cache.clear_storage_backend()
+            logger.info("Hierarchical cache cleared successfully!")
+            if_success = True
+        else:
+            logging.warning("Hierarchical cache is not enabled.")
+            if_success = False
+        return ClearHiCacheReqOutput(success=if_success)
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if (
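`clear_hicache_storage_wrapped` is the scheduler half of a request/response pair: the TokenizerManager (below) sends `ClearHiCacheReqInput` through a `_Communicator` and collects one `ClearHiCacheReqOutput` per DP rank. A condensed, transport-free sketch of the handler's contract, with stand-in dataclasses and a stub cache:

# Transport-free sketch of the ClearHiCache contract; the dataclasses mirror
# the io_struct names, and the tree cache is a stub.
from dataclasses import dataclass


@dataclass
class ClearHiCacheReqInput:
    pass


@dataclass
class ClearHiCacheReqOutput:
    success: bool


class StubTreeCache:
    def clear_storage_backend(self) -> None:
        print("hierarchical cache storage cleared")


def clear_hicache_storage_wrapped(
    enable_hierarchical_cache: bool, tree_cache: StubTreeCache
) -> ClearHiCacheReqOutput:
    # Only a hierarchical cache has a storage backend to clear.
    if enable_hierarchical_cache:
        tree_cache.clear_storage_backend()
        return ClearHiCacheReqOutput(success=True)
    return ClearHiCacheReqOutput(success=False)


assert clear_hicache_storage_wrapped(True, StubTreeCache()).success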
@@ -2335,6 +2404,10 @@
             # We still need to send something back to TokenizerManager to clean up the state.
             req = self.waiting_queue.pop(i)
             self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+            # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.tree_cache.cache_finished_req(req)
+
             logger.debug(f"Abort queued request. {req.rid=}")

         # Delete the requests in the grammar queue
@@ -2414,6 +2487,10 @@
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2513,7 +2590,15 @@ def is_health_check_generate_req(recv_req):


 def is_work_request(recv_req):
-    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+    return isinstance(
+        recv_req,
+        (
+            TokenizedGenerateReqInput,
+            TokenizedEmbeddingReqInput,
+            BatchTokenizedGenerateReqInput,
+            BatchTokenizedEmbeddingReqInput,
+        ),
+    )


 def run_scheduler_process(
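`is_work_request` now classifies the two new batch types as work as well; the same four-type tuple drives the work/control queue split in `process_input_requests` above. A compact illustration with empty stand-in request classes:

# Compact illustration of the work/control split; the request classes are
# empty stand-ins for the real io_struct types.
class TokenizedGenerateReqInput: ...
class TokenizedEmbeddingReqInput: ...
class BatchTokenizedGenerateReqInput: ...
class BatchTokenizedEmbeddingReqInput: ...
class FlushCacheReqInput: ...  # an example of a control request

WORK_TYPES = (
    TokenizedGenerateReqInput,
    TokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
    BatchTokenizedEmbeddingReqInput,
)


def split(recv_reqs):
    work = [r for r in recv_reqs if isinstance(r, WORK_TYPES)]
    control = [r for r in recv_reqs if not isinstance(r, WORK_TYPES)]
    return work, control


work, control = split([TokenizedGenerateReqInput(), FlushCacheReqInput()])
assert len(work) == 1 and len(control) == 1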
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
                 # This updates radix so others can match
                 self.tree_cache.cache_unfinished_req(req)

-                if req.return_logprob:
+                if batch.return_logprob:
                     assert extend_logprob_start_len_per_req is not None
                     assert extend_input_len_per_req is not None
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
                     num_input_logprobs = extend_input_len - extend_logprob_start_len
-                    self.add_logprob_return_values(
-                        i,
-                        req,
-                        logprob_pt,
-                        next_token_ids,
-                        num_input_logprobs,
-                        logits_output,
-                    )
+                    if req.return_logprob:
+                        self.add_logprob_return_values(
+                            i,
+                            req,
+                            logprob_pt,
+                            next_token_ids,
+                            num_input_logprobs,
+                            logits_output,
+                        )
                     logprob_pt += num_input_logprobs

                 if (
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
                     skip_stream_req = req

                 # Incrementally update input logprobs.
-                if req.return_logprob:
+                if batch.return_logprob:
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
                     if extend_logprob_start_len < extend_input_len:
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
                         num_input_logprobs = (
                             extend_input_len - extend_logprob_start_len
                         )
-                        self.add_input_logprob_return_values(
-                            i,
-                            req,
-                            logits_output,
-                            logprob_pt,
-                            num_input_logprobs,
-                            last_prefill_chunk=False,
-                        )
+                        if req.return_logprob:
+                            self.add_input_logprob_return_values(
+                                i,
+                                req,
+                                logits_output,
+                                logprob_pt,
+                                num_input_logprobs,
+                                last_prefill_chunk=False,
+                            )
                         logprob_pt += num_input_logprobs

         self.set_next_batch_sampling_info_done(batch)
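The reshuffled conditions matter for mixed batches: `batch.return_logprob` is set when any request in the batch wants logprobs, while the new inner `req.return_logprob` check skips emission for requests that do not. Crucially, `logprob_pt` still advances for skipped requests, so offsets into the shared logprob output stay aligned. A toy illustration of the two-level gate (names are simplified stand-ins for the mixin's variables):

# Toy two-level gate. Each request contributes num_input_logprobs entries to a
# shared buffer whether or not it asked for them, so the cursor must advance
# unconditionally once the batch-level flag is set.
def walk_logprobs(reqs, extend_input_lens, start_lens):
    batch_return_logprob = any(r["return_logprob"] for r in reqs)
    logprob_pt = 0
    spans = []
    for req, extend_len, start_len in zip(reqs, extend_input_lens, start_lens):
        if batch_return_logprob:
            num_input_logprobs = extend_len - start_len
            if req["return_logprob"]:
                spans.append((req["rid"], logprob_pt, logprob_pt + num_input_logprobs))
            logprob_pt += num_input_logprobs  # advance even when skipping
    return spans


spans = walk_logprobs(
    [{"rid": "a", "return_logprob": True}, {"rid": "b", "return_logprob": False},
     {"rid": "c", "return_logprob": True}],
    extend_input_lens=[4, 3, 5],
    start_lens=[0, 0, 0],
)
assert spans == [("a", 0, 4), ("c", 7, 12)]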
sglang/srt/managers/scheduler_update_weights_mixin.py
CHANGED
@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
         url = params["url"]

         worker = self.tp_worker.worker
-
         worker.model_runner.save_remote_model(url)

+        if self.draft_worker is not None:
+            draft_url = params.get("draft_url", None)
+            assert (
+                draft_url is not None
+            ), "draft_url must be provided when draft model is enabled"
+            draft_worker = self.draft_worker.worker
+            draft_worker.model_runner.save_remote_model(draft_url)
+
     def save_sharded_model(self, params):
         worker = self.tp_worker.worker

sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -71,6 +71,10 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -90,6 +94,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -262,9 +268,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )

         # Request states
         self.no_create_loop = False
@@ -308,35 +320,7 @@ class TokenizerManager:
         self.lora_update_lock = asyncio.Lock()

         # For PD disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_disaggregation()

         # For load balancing
         self.current_load = 0
@@ -384,6 +368,9 @@ class TokenizerManager:
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.clear_hicache_storage_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -445,6 +432,10 @@ class TokenizerManager:
                 SlowDownReqOutput,
                 self.slow_down_communicator.handle_recv,
             ),
+            (
+                ClearHiCacheReqOutput,
+                self.clear_hicache_storage_communicator.handle_recv,
+            ),
             (
                 FlushCacheReqOutput,
                 self.flush_cache_communicator.handle_recv,
@@ -477,6 +468,37 @@ class TokenizerManager:
             ]
         )

+    def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -486,6 +508,15 @@
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
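The `f"{worker_id}_{rid}"` prefix added here is stripped again on responses via `get_origin_rid`, imported from `sglang.srt.utils` above. A minimal sketch of the likely pairing; the split-on-first-underscore behavior is an assumption inferred from the prefix format, not quoted from the real helper:

# Minimal sketch of the rid prefix round trip. get_origin_rid is assumed to
# strip a single "{worker_id}_" prefix; the real helper lives in sglang.srt.utils.
def add_worker_prefix(rid: str, worker_id: int) -> str:
    return f"{worker_id}_{rid}"


def get_origin_rid(rid: str) -> str:
    # Split on the first underscore only, since the original rid may itself
    # contain underscores.
    return rid.split("_", 1)[1] if "_" in rid else rid


assert get_origin_rid(add_worker_prefix("req_0042", worker_id=31337)) == "req_0042"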
@@ -768,6 +799,30 @@
         self.rid_to_state[obj.rid] = state
         return state

+    def _send_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        tokenized_objs: List[
+            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+        ],
+        created_time: Optional[float] = None,
+    ):
+        """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
+            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
+        else:
+            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
+
+        self.send_to_scheduler.send_pyobj(batch_req)
+
+        # Create states for each individual request in the batch
+        for i, tokenized_obj in enumerate(tokenized_objs):
+            tmp_obj = obj[i]
+            state = ReqState(
+                [], False, asyncio.Event(), tmp_obj, created_time=created_time
+            )
+            self.rid_to_state[tmp_obj.rid] = state
+
     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -870,10 +925,17 @@

             tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

-            … (1 removed line not shown in the source diff)
+            # Send as a single batched request
+            self._send_batch_request(obj, tokenized_objs, created_time)
+
+            # Set up generators for each request in the batch
+            for i in range(batch_size):
                 tmp_obj = obj[i]
-            … (2 removed lines not shown in the source diff)
+                generators.append(
+                    self._wait_one_response(
+                        tmp_obj, self.rid_to_state[tmp_obj.rid], request
+                    )
+                )
                 rids.append(tmp_obj.rid)
         else:
             # Sequential tokenization and processing
@@ -955,6 +1017,13 @@
     async def flush_cache(self) -> FlushCacheReqOutput:
         return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

+    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
+        """Clear the hierarchical cache storage."""
+        # Delegate to the scheduler to handle HiCacheStorage clearing
+        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
+            0
+        ]
+
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
@@ -1047,6 +1116,8 @@
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1266,6 +1337,8 @@
         elif obj.session_id in self.session_futures:
             return None

+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1286,13 +1359,11 @@
         # Many DP ranks
         return [res.internal_state for res in responses]

-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
+    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
         responses: List[SetInternalStateReqOutput] = (
             await self.set_internal_state_communicator(obj)
         )
-        return [res.… (rest of line not shown in the source diff)
+        return [res.updated for res in responses]

     async def get_load(self) -> dict:
         # TODO(lsyin): fake load report server
@@ -1543,7 +1614,6 @@

     async def handle_loop(self):
        """The event loop that handles requests"""
-
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
@@ -1563,9 +1633,12 @@
                 )
                 continue

+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1871,6 +1944,9 @@
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1883,7 +1959,7 @@
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2069,6 +2145,8 @@ T = TypeVar("T")

 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""

+    enable_multi_tokenizer = False
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
@@ -2085,6 +2163,8 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)

         self._result_event = asyncio.Event()
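`enable_multi_tokenizer` is a class attribute, so flipping it once (presumably during multi-tokenizer startup) makes every `_Communicator` instance tag outbound objects with the sending process's PID. A reduced sketch of that send path with the ZMQ socket stubbed out:

# Reduced sketch of the _Communicator send path; the sender is a stub and the
# envelope matches the (worker_id, obj) shape used elsewhere in this diff.
import os
from dataclasses import dataclass
from typing import Any


@dataclass
class MultiTokenizerWarpper:
    worker_id: int
    obj: Any


class StubSender:
    def send_pyobj(self, obj: Any) -> None:
        print(f"sent: {obj!r}")


class _Communicator:
    enable_multi_tokenizer = False  # class-level switch shared by all instances

    def __init__(self, sender: StubSender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out  # number of replies to collect (one per DP rank)

    def send(self, obj: Any) -> None:
        if obj:
            if _Communicator.enable_multi_tokenizer:
                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
            self._sender.send_pyobj(obj)


_Communicator.enable_multi_tokenizer = True
_Communicator(StubSender(), fan_out=1).send("flush_cache")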
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]