sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +14 -1
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +27 -15
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +14 -13
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +9 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
- sglang/srt/managers/scheduler.py +39 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +94 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +4 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +31 -10
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +4 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py
@@ -102,10 +102,7 @@ class HiRadixCache(RadixCache):
         self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
-            1 if hicache_write_policy == "write_through" else
-        )
-        self.write_through_threshold_storage = (
-            1 if hicache_write_policy == "write_through" else 3
+            1 if hicache_write_policy == "write_through" else 2
         )
         self.load_back_threshold = 10
         super().__init__(
@@ -125,6 +122,15 @@ class HiRadixCache(RadixCache):
             height += 1
         return height
 
+    def clear_storage_backend(self):
+        if self.enable_storage:
+            self.cache_controller.storage_backend.clear()
+            logger.info("Hierarchical cache storage backend cleared successfully!")
+            return True
+        else:
+            logger.warning("Hierarchical cache storage backend is not enabled.")
+            return False
+
     def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
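The new `clear_storage_backend` helper simply delegates to `cache_controller.storage_backend.clear()` and reports whether anything was cleared, so callers can branch on its boolean result. A minimal usage sketch; the `tree_cache` handle is a hypothetical `HiRadixCache` instance, not taken from the diff:

```python
def maybe_clear_storage(tree_cache) -> None:
    # `tree_cache` is assumed to be a HiRadixCache constructed with
    # hierarchical storage enabled (hypothetical handle).
    if tree_cache.clear_storage_backend():
        print("hierarchical cache storage backend cleared")
    else:
        print("hierarchical storage is not enabled; nothing to clear")
```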
@@ -155,8 +161,9 @@ class HiRadixCache(RadixCache):
         self.ongoing_backup[operation_id] = node
         node.protect_host()
 
-    def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy == "write_back":
+    def _inc_hit_count(self, node: TreeNode, chunked=False):
+        # skip the hit count update for chunked requests
+        if self.cache_controller.write_policy == "write_back" or chunked:
             return
         node.hit_count += 1
 
@@ -164,14 +171,6 @@ class HiRadixCache(RadixCache):
         if node.hit_count >= self.write_through_threshold:
             # write to host if the node is not backuped
             self.write_backup(node)
-        else:
-            if (
-                self.enable_storage
-                and (not node.backuped_storage)
-                and node.hit_count >= self.write_through_threshold_storage
-            ):
-                # if the node is backuped on host memory but not on storage
-                self.write_backup_storage(node)
 
     def writing_check(self, write_back=False):
         if write_back:
@@ -192,8 +191,11 @@ class HiRadixCache(RadixCache):
             )
         for _ in range(queue_size.item()):
             ack_id = self.cache_controller.ack_write_queue.get()
-            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            backuped_node = self.ongoing_write_through[ack_id]
+            self.dec_lock_ref(backuped_node)
             del self.ongoing_write_through[ack_id]
+            if self.enable_storage:
+                self.write_backup_storage(backuped_node)
 
     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
@@ -376,57 +378,54 @@ class HiRadixCache(RadixCache):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
-            … (6 removed lines not captured in this diff view)
+            self.drain_storage_control_queues()
+
+    def drain_storage_control_queues(self):
+        """
+        Combine prefetch revoke, backup ack, and host mem release checks
+        to minimize TP synchronization and Python overhead.
+        """
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                cc.prefetch_revoke_queue.qsize(),
+                cc.ack_backup_queue.qsize(),
+                cc.host_mem_release_queue.qsize(),
+            ],
+            dtype=torch.int,
         )
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
             )
-        for _ in range(queue_size.item()):
-            req_id = self.cache_controller.prefetch_revoke_queue.get()
-            if req_id in self.ongoing_prefetch:
-                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
-                last_host_node.release_host()
-                del self.ongoing_prefetch[req_id]
-                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
-            else:
-                # the revoked operation already got terminated
-                pass
 
-        … (26 removed lines not captured in this diff view)
+        n_revoke, n_backup, n_release = map(int, qsizes.tolist())
+
+        # process prefetch revokes
+        for _ in range(n_revoke):
+            req_id = cc.prefetch_revoke_queue.get()
+            info = self.ongoing_prefetch.pop(req_id, None)
+            if info is not None:
+                last_host_node, token_ids, _, _ = info
+                last_host_node.release_host()
+                cc.prefetch_tokens_occupied -= len(token_ids)
+            # else: the revoked operation already got terminated, nothing to do
+
+        # process backup acks
+        for _ in range(n_backup):
+            ack_id = cc.ack_backup_queue.get()
+            entry = self.ongoing_backup.pop(ack_id, None)
+            if entry is not None:
+                entry.release_host()
+
+        # release host memory
+        host_indices_list = []
+        for _ in range(n_release):
+            host_indices_list.append(cc.host_mem_release_queue.get())
+        if host_indices_list:
+            host_indices = torch.cat(host_indices_list, dim=0)
+            cc.mem_pool_host.free(host_indices)
 
     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
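The docstring above states the goal: one batched collective replaces several per-queue synchronizations. Stacking the three queue sizes into a single tensor and reducing with `ReduceOp.MIN` means every TP rank drains only as many items as all ranks have observed, keeping the workers in lockstep. A standalone sketch of the same pattern, with hypothetical queues and process group (not sglang code):

```python
import torch
import torch.distributed as dist
from queue import Queue

def drain_in_lockstep(queues: list[Queue], group=None) -> list[list]:
    # One tensor holds all local queue sizes; a single all_reduce(MIN)
    # agrees on how many items every rank is allowed to pop per queue.
    sizes = torch.tensor([q.qsize() for q in queues], dtype=torch.int)
    if dist.is_initialized() and dist.get_world_size(group=group) > 1:
        dist.all_reduce(sizes, op=dist.ReduceOp.MIN, group=group)
    return [[q.get() for _ in range(int(n))] for q, n in zip(queues, sizes.tolist())]
```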
@@ -469,9 +468,9 @@ class HiRadixCache(RadixCache):
 
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
             req_id
-        ]
+        )
 
         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -509,11 +508,10 @@ class HiRadixCache(RadixCache):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)
 
         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.mem_pool_host.free(
+        self.cache_controller.append_host_mem_release(
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
-        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
 
         return True
@@ -565,7 +563,11 @@ class HiRadixCache(RadixCache):
             len(new_input_tokens) % self.page_size
         )
         new_input_tokens = new_input_tokens[:prefetch_length]
-        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+        if (
+            not self.enable_storage
+            or prefetch_length < self.prefetch_threshold
+            or self.cache_controller.prefetch_rate_limited()
+        ):
             return
 
         last_host_node.protect_host()
@@ -573,6 +575,10 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+            if host_indices is None:
+                last_host_node.release_host()
+                # no sufficient host memory for prefetch
+                return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
@@ -672,11 +678,11 @@ class HiRadixCache(RadixCache):
         new_node.parent.children[self.get_child_key_fn(key)] = new_node
         return new_node
 
-    def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.monotonic()
+    def insert(self, key: List, value, chunked=False):
         if len(key) == 0:
             return 0
 
+        node = self.root_node
         child_key = self.get_child_key_fn(key)
         total_prefix_length = 0
 
@@ -693,7 +699,7 @@ class HiRadixCache(RadixCache):
                 self.token_to_kv_pool_host.update_synced(node.host_value)
                 self.evictable_size_ += len(node.value)
             else:
-                self.inc_hit_count(node)
+                self._inc_hit_count(node, chunked)
             total_prefix_length += prefix_len
         else:
             # partial match, split the node
@@ -703,7 +709,7 @@ class HiRadixCache(RadixCache):
                 self.token_to_kv_pool_host.update_synced(new_node.host_value)
                 self.evictable_size_ += len(new_node.value)
             else:
-                self.inc_hit_count(new_node)
+                self._inc_hit_count(new_node, chunked)
             total_prefix_length += prefix_len
             node = new_node
 
@@ -737,7 +743,7 @@ class HiRadixCache(RadixCache):
                 last_hash = new_node.hash_value[-1]
 
         if self.cache_controller.write_policy != "write_back":
-            self.inc_hit_count(new_node)
+            self._inc_hit_count(new_node, chunked)
         return total_prefix_length
 
     def _collect_leaves_device(self):
@@ -764,3 +770,20 @@ class HiRadixCache(RadixCache):
                 if not cur_child.evicted:
                     stack.append(cur_child)
         return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+            rid
+        )
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
sglang/srt/mem_cache/lora_radix_cache.py
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
sglang/srt/mem_cache/memory_pool.py
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.kv_lora_rank,
             ),
             dtype=self.store_dtype,
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.qk_rope_head_dim,
             ),
             dtype=self.store_dtype,
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
 
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
 
         if cache_v is None:
             cache_k, cache_v = cache_k.split(
sglang/srt/mem_cache/memory_pool_host.py
@@ -7,7 +7,6 @@ from functools import wraps
 import psutil
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -464,11 +463,11 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
-        local_rank = get_tensor_model_parallel_rank()
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
@@ -704,10 +703,11 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
sglang/srt/mem_cache/radix_cache.py
@@ -62,7 +62,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -195,7 +194,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )
 
-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
         if self.disable:
             return 0
 
@@ -240,7 +239,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -261,7 +260,9 @@ class RadixCache(BasePrefixCache):
         page_aligned_token_ids = token_ids[:page_aligned_len]
 
         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
sglang/srt/mem_cache/radix_cache_cpp.py
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple
 
+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request, …
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
@@ -24,10 +26,10 @@ class RankMetadata:
     """Holds all metadata for a single rank."""
 
     def __init__(self, num_pages: int):
-        self.lock = threading.…
+        self.lock = threading.Lock()
         self.num_pages = num_pages
         self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index: Dict[str, int] = {}
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
         # Todo: Support multi files for HF3FS
 
     def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
             for i, (key, prefix_key) in enumerate(keys):
                 if key in self.key_to_index:
                     results[i] = (True, self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
                 else:
                     new_keys_to_process.append((i, key, prefix_key))
 
             # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
             for i, key, prefix_key in new_keys_to_process:
                 if len(self.free_pages) > 0:
-                    page_idx = self.free_pages.pop()
-                    results[i] = (False, page_idx)
+                    page_index = self.free_pages.pop()
                 else:
-                    …
+                    page_index = self.key_to_index.popitem(last=False)[1]
+
+                results[i] = (False, page_index)
 
             return results
 
@@ -68,6 +72,7 @@ class RankMetadata:
         with self.lock:
             for key, page_index in written_keys_to_confirm:
                 self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)
 
             for page_index in pages_to_release:
                 if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
     def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
         """Get page indices for keys."""
         with self.lock:
-            …
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results
 
 
 class GlobalMetadataState:
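Taken together, the `RankMetadata` changes turn `key_to_index` into an LRU map: hits call `move_to_end`, and once `free_pages` is exhausted the oldest entry is evicted with `popitem(last=False)` and its page reused. A small self-contained sketch of that OrderedDict-based policy; the class and method names below are illustrative, not sglang's:

```python
from collections import OrderedDict
from typing import Optional

class LruPageMap:
    def __init__(self, num_pages: int):
        self.free_pages = list(range(num_pages))
        self.key_to_index: OrderedDict[str, int] = OrderedDict()

    def get(self, key: str) -> Optional[int]:
        idx = self.key_to_index.get(key)
        if idx is not None:
            self.key_to_index.move_to_end(key)  # mark as most recently used
        return idx

    def allocate(self, key: str) -> int:
        if key in self.key_to_index:
            self.key_to_index.move_to_end(key)
            return self.key_to_index[key]
        if self.free_pages:
            idx = self.free_pages.pop()
        else:
            # evict the least recently used key and reuse its page
            _, idx = self.key_to_index.popitem(last=False)
        self.key_to_index[key] = idx
        return idx
```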
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
 
     def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
         self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
         self._setup_routes()
 
     def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
 
     def get_rank_metadata(self, rank: int) -> RankMetadata:
         """Get rank metadata with proper error handling."""
-        … (7 removed lines not captured in this diff view)
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)
 
     async def initialize(self, rank: int, request: Request):
         """Initialize a rank with specified number of pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         num_pages = data["num_pages"]
         with self.state.global_lock:
             if rank in self.state.ranks:
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
             else:
                 logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                 self.state.ranks[rank] = RankMetadata(num_pages)
-        return …
+        return Response(status_code=204)
 
     async def exists(self, rank: int, request: Request):
         """Check if keys exist in metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
        keys = data["keys"]
         metadata = self.get_rank_metadata(rank)
         results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})
 
     async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
         """Reserve and allocate page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})
 
     async def confirm_write(self, rank: int, request: Request):
         """Confirm write operations and release pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         success_written_keys = data.get("written_keys_to_confirm", [])
         released_pages = data.get("pages_to_release", [])
 
         metadata.confirm_write(success_written_keys, released_pages)
 
-        return {
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)
 
     async def delete_keys(self, rank: int, request: Request):
         """Delete keys from metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         count = metadata.delete_keys(data["keys"])
-        return …
+        return Response(status_code=204)
 
     async def clear(self, rank: int):
         """Clear all metadata for a rank."""
         metadata = self.get_rank_metadata(rank)
         metadata.clear_all()
-        return …
+        return Response(status_code=204)
 
     async def get_page_indices(self, rank: int, request: Request):
         """Get page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})
 
     def run(self, host: str = "0.0.0.0", port: int = 18000):
         """Run the metadata server."""
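The server-side pattern in the hunks above is consistent: parse the raw body with `orjson.loads`, reply with `ORJSONResponse` when there is a payload, and fall back to a bare 204 `Response` when there is nothing to return, bypassing FastAPI's default `jsonable_encoder` path. A minimal standalone endpoint written in the same style; the route names and the placeholder lookup are hypothetical, not the sglang server:

```python
import orjson
from fastapi import FastAPI, Request, Response
from fastapi.responses import ORJSONResponse

app = FastAPI(default_response_class=ORJSONResponse)

@app.post("/exists")
async def exists(request: Request):
    # Parse the raw body with orjson instead of request.json()
    data = orjson.loads(await request.body())
    keys = data["keys"]
    results = [False for _ in keys]  # placeholder lookup
    return ORJSONResponse({"exists": results})

@app.post("/confirm")
async def confirm(request: Request):
    _ = orjson.loads(await request.body())
    # Nothing to return: answer with an empty 204 response
    return Response(status_code=204)
```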
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
             status_forcelist=[500, 502, 503, 504],
             allowed_methods=["GET", "POST"],
         )
-        adapter = HTTPAdapter(max_retries=retry_strategy)
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
         self._session.mount("http://", adapter)
 
     def _post(self, endpoint: str, json_data: dict) -> dict:
         try:
-            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
             response.raise_for_status()
-            return response.json()
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
         except requests.exceptions.RequestException as e:
             logging.error(f"Failed to POST to {endpoint} after retries: {e}")
             raise RuntimeError(f"Failed to connect to metadata server: {e}") from e