sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +14 -1
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +27 -15
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +14 -13
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +9 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
- sglang/srt/managers/scheduler.py +39 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +94 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +4 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +31 -10
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +4 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -69,6 +69,8 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -82,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -255,7 +259,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -515,6 +518,7 @@ class Scheduler(
             (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
             (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
             (FlushCacheReqInput, self.flush_cache_wrapped),
+            (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
             (AbortReq, self.abort_request),
             (OpenSessionReqInput, self.open_session),
             (CloseSessionReqInput, self.close_session),
@@ -537,6 +541,7 @@ class Scheduler(
             (ExpertDistributionReq, self.expert_distribution_handle),
             (LoadLoRAAdapterReqInput, self.load_lora_adapter),
             (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+            (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
         ]
     )
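Note: the two dispatcher hunks above register new (request type, handler) pairs. A minimal sketch of the assumed dispatch semantics (the real TypeBasedDispatcher lives in sglang.srt.utils and may differ in detail):

from typing import Any, Callable, List, Tuple

class TypeBasedDispatcher:
    def __init__(self, mapping: List[Tuple[type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        # First entry whose type matches the incoming request wins.
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj}")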
@@ -1098,6 +1103,17 @@ class Scheduler(
                 )
                 self.send_to_tokenizer.send_pyobj(abort_req)
                 continue
+
+            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWarpper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWarpper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
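Note: a self-contained sketch of the envelope pattern this hunk relies on, assuming MultiTokenizerWarpper (spelling as in the actual sglang identifier) is essentially a (worker_id, obj) pair:

from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class Envelope:  # stand-in for MultiTokenizerWarpper (assumed shape)
    worker_id: int
    obj: Any

def handle(recv_req: Any,
           dispatch: Callable[[Any], Optional[Any]],
           send: Callable[[Any], None]) -> None:
    # Unwrap, dispatch the inner request, then re-wrap the reply so it can
    # be routed back to the tokenizer worker that sent it.
    if isinstance(recv_req, Envelope):
        output = dispatch(recv_req.obj)
        if output is not None:
            send(Envelope(recv_req.worker_id, output))
        return
    output = dispatch(recv_req)
    if output is not None:
        send(output)

handle(Envelope(7, "ping"), lambda o: f"pong:{o}", print)  # Envelope(worker_id=7, obj='pong:ping')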
@@ -1503,7 +1519,7 @@ class Scheduler(
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             chunked_req_to_exclude.add(self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
@@ -2207,6 +2223,16 @@ class Scheduler(
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)

+    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+        if self.enable_hierarchical_cache:
+            self.tree_cache.clear_storage_backend()
+            logger.info("Hierarchical cache cleared successfully!")
+            if_success = True
+        else:
+            logging.warning("Hierarchical cache is not enabled.")
+            if_success = False
+        return ClearHiCacheReqOutput(success=if_success)
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if (
@@ -2377,7 +2403,14 @@ class Scheduler(
                 # This only works for requests that have not started anything.
                 # We still need to send something back to TokenizerManager to clean up the state.
                 req = self.waiting_queue.pop(i)
+                if self.enable_hicache_storage:
+                    # to release prefetch events associated with the request
+                    self.tree_cache.release_aborted_request(req.rid)
                 self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+                # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+                if self.disaggregation_mode == DisaggregationMode.DECODE:
+                    self.tree_cache.cache_finished_req(req)
+
                 logger.debug(f"Abort queued request. {req.rid=}")

         # Delete the requests in the grammar queue
@@ -2457,6 +2490,10 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)

-                    if req.return_logprob:
+                    if batch.return_logprob:
                         assert extend_logprob_start_len_per_req is not None
                         assert extend_input_len_per_req is not None
                         extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                         extend_input_len = extend_input_len_per_req[i]
                         num_input_logprobs = extend_input_len - extend_logprob_start_len
-                        self.add_logprob_return_values(
-                            i,
-                            req,
-                            logprob_pt,
-                            next_token_ids,
-                            num_input_logprobs,
-                            logits_output,
-                        )
+                        if req.return_logprob:
+                            self.add_logprob_return_values(
+                                i,
+                                req,
+                                logprob_pt,
+                                next_token_ids,
+                                num_input_logprobs,
+                                logits_output,
+                            )
                         logprob_pt += num_input_logprobs

                     if (
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
                     skip_stream_req = req

                 # Incrementally update input logprobs.
-                if req.return_logprob:
+                if batch.return_logprob:
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
                     if extend_logprob_start_len < extend_input_len:
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
                         num_input_logprobs = (
                             extend_input_len - extend_logprob_start_len
                         )
-                        self.add_input_logprob_return_values(
-                            i,
-                            req,
-                            logits_output,
-                            logprob_pt,
-                            num_input_logprobs,
-                            last_prefill_chunk=False,
-                        )
+                        if req.return_logprob:
+                            self.add_input_logprob_return_values(
+                                i,
+                                req,
+                                logits_output,
+                                logprob_pt,
+                                num_input_logprobs,
+                                last_prefill_chunk=False,
+                            )
                         logprob_pt += num_input_logprobs

             self.set_next_batch_sampling_info_done(batch)
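Note: the fix above separates the batch-level gate (batch.return_logprob, which decides whether logprobs were computed for the batch) from the per-request gate (req.return_logprob, which decides whether they are emitted). A toy illustration of why logprob_pt must still advance for opted-out requests (illustrative indexing only, not sglang code):

# In a mixed batch, input logprobs are packed contiguously across requests,
# so each request's slice must be skipped even if that request opted out.
extend_input_lens = [5, 3, 4]        # toy per-request extend lengths
wants_logprob = [True, False, True]  # per-request req.return_logprob

logprob_pt = 0
for i, n in enumerate(extend_input_lens):
    if wants_logprob[i]:
        print(f"req {i}: logprob slice [{logprob_pt}, {logprob_pt + n})")
    logprob_pt += n  # must advance unconditionally, as in the fix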
sglang/srt/managers/scheduler_update_weights_mixin.py
CHANGED
@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
         url = params["url"]

         worker = self.tp_worker.worker
-
         worker.model_runner.save_remote_model(url)

+        if self.draft_worker is not None:
+            draft_url = params.get("draft_url", None)
+            assert (
+                draft_url is not None
+            ), "draft_url must be provided when draft model is enabled"
+            draft_worker = self.draft_worker.worker
+            draft_worker.model_runner.save_remote_model(draft_url)
+
     def save_sharded_model(self, params):
         worker = self.tp_worker.worker
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -92,6 +94,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -129,6 +132,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -264,9 +268,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )

         # Request states
         self.no_create_loop = False
@@ -310,35 +320,7 @@ class TokenizerManager:
         self.lora_update_lock = asyncio.Lock()

         # For PD disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_disaggregation()

         # For load balancing
         self.current_load = 0
@@ -386,6 +368,9 @@ class TokenizerManager:
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.clear_hicache_storage_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -447,6 +432,10 @@ class TokenizerManager:
                 SlowDownReqOutput,
                 self.slow_down_communicator.handle_recv,
             ),
+            (
+                ClearHiCacheReqOutput,
+                self.clear_hicache_storage_communicator.handle_recv,
+            ),
             (
                 FlushCacheReqOutput,
                 self.flush_cache_communicator.handle_recv,
@@ -479,6 +468,37 @@ class TokenizerManager:
             ]
         )

+    def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -488,6 +508,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -988,6 +1017,13 @@ class TokenizerManager:
     async def flush_cache(self) -> FlushCacheReqOutput:
         return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

+    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
+        """Clear the hierarchical cache storage."""
+        # Delegate to the scheduler to handle HiCacheStorage clearing
+        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
+            0
+        ]
+
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
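Note: a hypothetical caller of the new API; only clear_hicache_storage() itself comes from this diff:

# Assumes a running TokenizerManager instance in an async context.
async def clear_hicache(tokenizer_manager) -> bool:
    out = await tokenizer_manager.clear_hicache_storage()  # fans out over dp_size ranks
    return out.success  # False when hierarchical cache is not enabled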
@@ -1080,6 +1116,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1299,6 +1337,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None

+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1319,13 +1359,11 @@ class TokenizerManager:
         # Many DP ranks
         return [res.internal_state for res in responses]

-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
+    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
         responses: List[SetInternalStateReqOutput] = (
             await self.set_internal_state_communicator(obj)
         )
-        return [res.internal_state for res in responses]
+        return [res.updated for res in responses]

     async def get_load(self) -> dict:
         # TODO(lsyin): fake load report server
@@ -1576,7 +1614,6 @@ class TokenizerManager:

     async def handle_loop(self):
         """The event loop that handles requests"""
-
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1596,9 +1633,12 @@ class TokenizerManager:
                 )
                 continue

+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1904,6 +1944,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1916,7 +1959,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2102,6 +2145,8 @@ T = TypeVar("T")
 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""

+    enable_multi_tokenizer = False
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
@@ -2118,6 +2163,8 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)

         self._result_event = asyncio.Event()
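Note: a sketch of the send-side counterpart of the class-level flag added above (toy stand-ins; only the pid-tagging logic mirrors the diff):

import os
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Envelope:  # stand-in for MultiTokenizerWarpper
    worker_id: int
    obj: Any

class Communicator:
    enable_multi_tokenizer = False  # class-level switch, as on _Communicator

    def __init__(self, send_fn: Callable[[Any], None]):
        self._send = send_fn

    def request(self, obj: Any) -> None:
        # In multi-tokenizer mode, tag outbound requests with this process's
        # pid so the scheduler can route the reply back to the right worker.
        if Communicator.enable_multi_tokenizer:
            obj = Envelope(worker_id=os.getpid(), obj=obj)
        self._send(obj)

Communicator.enable_multi_tokenizer = True
Communicator(print).request("flush_cache")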
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]
sglang/srt/mem_cache/hicache_storage.py
CHANGED
@@ -102,6 +102,20 @@ class HiCacheStorage(ABC):
         """
         pass

+    @abstractmethod
+    def delete(self, key: str) -> bool:
+        """
+        Delete the entry associated with the given key.
+        """
+        pass
+
+    @abstractmethod
+    def clear(self) -> bool:
+        """
+        Clear all entries in the storage.
+        """
+        pass
+
     def batch_exists(self, keys: List[str]) -> int:
         """
         Check if the keys exist in the storage.
@@ -175,11 +189,12 @@ class HiCacheFile(HiCacheStorage):
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
-        key = self._get_suffixed_key(key)
-        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
+
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
@@ -213,12 +228,14 @@ class HiCacheFile(HiCacheStorage):
             logger.warning(f"Key {key} does not exist. Cannot delete.")
             return

-    def clear(self) -> None:
+    def clear(self) -> bool:
         try:
             for filename in os.listdir(self.file_path):
                 file_path = os.path.join(self.file_path, filename)
                 if os.path.isfile(file_path):
                     os.remove(file_path)
             logger.info("Cleared all entries in HiCacheFile storage.")
+            return True
         except Exception as e:
             logger.error(f"Failed to clear HiCacheFile storage: {e}")
+            return False