sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/hiradix_cache.py

@@ -102,10 +102,7 @@ class HiRadixCache(RadixCache):
         self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
-            1 if hicache_write_policy == "write_through" else 3
-        )
-        self.write_through_threshold_storage = (
-            1 if hicache_write_policy == "write_through" else 3
+            1 if hicache_write_policy == "write_through" else 2
         )
         self.load_back_threshold = 10
         super().__init__(
@@ -125,6 +122,15 @@ class HiRadixCache(RadixCache):
             height += 1
         return height

+    def clear_storage_backend(self):
+        if self.enable_storage:
+            self.cache_controller.storage_backend.clear()
+            logger.info("Hierarchical cache storage backend cleared successfully!")
+            return True
+        else:
+            logger.warning("Hierarchical cache storage backend is not enabled.")
+            return False
+
     def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
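
As a usage note, the new `clear_storage_backend` helper returns a boolean instead of raising, so a caller can branch on whether hierarchical storage was actually enabled. A minimal sketch, assuming `tree_cache` is an already-constructed `HiRadixCache`; the flush helper itself is hypothetical and not part of this diff:

```python
# Hypothetical flush helper; `tree_cache` is assumed to be an existing HiRadixCache.
def flush_hierarchical_storage(tree_cache) -> bool:
    cleared = tree_cache.clear_storage_backend()
    if not cleared:
        # clear_storage_backend() returns False when enable_storage is off
        print("hierarchical cache storage backend is not enabled; nothing cleared")
    return cleared
```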
@@ -155,8 +161,9 @@ class HiRadixCache(RadixCache):
         self.ongoing_backup[operation_id] = node
         node.protect_host()

-    def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy == "write_back":
+    def _inc_hit_count(self, node: TreeNode, chunked=False):
+        # skip the hit count update for chunked requests
+        if self.cache_controller.write_policy == "write_back" or chunked:
             return
         node.hit_count += 1

@@ -164,14 +171,6 @@ class HiRadixCache(RadixCache):
            if node.hit_count >= self.write_through_threshold:
                # write to host if the node is not backuped
                self.write_backup(node)
-            else:
-                if (
-                    self.enable_storage
-                    and (not node.backuped_storage)
-                    and node.hit_count >= self.write_through_threshold_storage
-                ):
-                    # if the node is backuped on host memory but not on storage
-                    self.write_backup_storage(node)

    def writing_check(self, write_back=False):
        if write_back:
@@ -192,8 +191,11 @@ class HiRadixCache(RadixCache):
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
-            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            backuped_node = self.ongoing_write_through[ack_id]
+            self.dec_lock_ref(backuped_node)
            del self.ongoing_write_through[ack_id]
+            if self.enable_storage:
+                self.write_backup_storage(backuped_node)

    def loading_check(self):
        while not self.cache_controller.ack_load_queue.empty():
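
Taken together, the three hunks above replace the separate storage hit-count threshold with a chained flow: a node is written through to host once its hit count reaches the (now lower) threshold, and the storage backup is issued only after the host write is acknowledged. A condensed sketch of that ordering, using simplified stand-ins rather than the real HiRadixCache and cache-controller types:

```python
# Simplified stand-ins, only to illustrate the new ordering:
# hit count >= threshold -> back up to host -> on write ack -> back up to storage.
class WriteThroughFlowSketch:
    def __init__(self, write_policy: str, enable_storage: bool):
        # mirrors the updated default: 1 for "write_through", otherwise 2
        self.write_through_threshold = 1 if write_policy == "write_through" else 2
        self.enable_storage = enable_storage
        self.pending_host_writes = {}  # op_id -> node

    def on_hit(self, node) -> None:
        node.hit_count += 1
        if node.hit_count >= self.write_through_threshold:
            op_id = self.backup_to_host(node)      # stand-in for write_backup()
            self.pending_host_writes[op_id] = node

    def on_host_write_ack(self, op_id) -> None:
        node = self.pending_host_writes.pop(op_id)
        if self.enable_storage:
            self.backup_to_storage(node)           # stand-in for write_backup_storage()

    # placeholders standing in for cache-controller calls
    def backup_to_host(self, node) -> int:
        return id(node)

    def backup_to_storage(self, node) -> None:
        pass
```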
@@ -376,57 +378,54 @@ class HiRadixCache(RadixCache):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
-            self.check_revoked_prefetch()
-            self.check_backup_progress()
-
-    def check_revoked_prefetch(self):
-        queue_size = torch.tensor(
-            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+            self.drain_storage_control_queues()
+
+    def drain_storage_control_queues(self):
+        """
+        Combine prefetch revoke, backup ack, and host mem release checks
+        to minimize TP synchronization and Python overhead.
+        """
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                cc.prefetch_revoke_queue.qsize(),
+                cc.ack_backup_queue.qsize(),
+                cc.host_mem_release_queue.qsize(),
+            ],
+            dtype=torch.int,
         )
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
             )
-        for _ in range(queue_size.item()):
-            req_id = self.cache_controller.prefetch_revoke_queue.get()
-            if req_id in self.ongoing_prefetch:
-                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
-                last_host_node.release_host()
-                del self.ongoing_prefetch[req_id]
-                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
-            else:
-                # the revoked operation already got terminated
-                pass

-    def check_backup_progress(self):
-        queue_size = torch.tensor(
-            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
-        )
-        if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
-            torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
-            )
-        for _ in range(queue_size.item()):
-            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
-            host_node = self.ongoing_backup[ack_id]
-
-            if completed_tokens > 0:
-                if completed_tokens < len(host_node.key):
-                    # backup is only partially successful, split the node
-                    new_node = self._split_node(
-                        host_node.key, host_node, completed_tokens
-                    )
-                    new_node.backuped_storage = True
-                else:
-                    host_node.backuped_storage = True
-            host_node.release_host()
-            del self.ongoing_backup[ack_id]
+        n_revoke, n_backup, n_release = map(int, qsizes.tolist())
+
+        # process prefetch revokes
+        for _ in range(n_revoke):
+            req_id = cc.prefetch_revoke_queue.get()
+            info = self.ongoing_prefetch.pop(req_id, None)
+            if info is not None:
+                last_host_node, token_ids, _, _ = info
+                last_host_node.release_host()
+                cc.prefetch_tokens_occupied -= len(token_ids)
+            # else: the revoked operation already got terminated, nothing to do
+
+        # process backup acks
+        for _ in range(n_backup):
+            ack_id = cc.ack_backup_queue.get()
+            entry = self.ongoing_backup.pop(ack_id, None)
+            if entry is not None:
+                entry.release_host()
+
+        # release host memory
+        host_indices_list = []
+        for _ in range(n_release):
+            host_indices_list.append(cc.host_mem_release_queue.get())
+        if host_indices_list:
+            host_indices = torch.cat(host_indices_list, dim=0)
+            cc.mem_pool_host.free(host_indices)

     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
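
The docstring above states the intent: fold three queue-size checks into a single collective so every TP rank drains the same number of entries. A self-contained sketch of that pattern, assuming `torch.distributed` has been initialized elsewhere; the helper name is ours, not sglang's:

```python
import torch
import torch.distributed as dist


def agreed_queue_sizes(local_sizes: list[int], group=None) -> list[int]:
    """Element-wise MIN of per-rank queue sizes.

    Each rank proposes how many items it currently sees in each control queue;
    reducing with MIN guarantees all ranks pop the same number of entries,
    using one all_reduce instead of one collective per queue.
    """
    sizes = torch.tensor(local_sizes, dtype=torch.int)
    if dist.is_initialized() and dist.get_world_size(group=group) > 1:
        dist.all_reduce(sizes, op=dist.ReduceOp.MIN, group=group)
    return [int(s) for s in sizes.tolist()]


# e.g. n_revoke, n_backup, n_release = agreed_queue_sizes(
#     [revoke_q.qsize(), backup_q.qsize(), release_q.qsize()], group=tp_group
# )
```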
@@ -469,9 +468,9 @@ class HiRadixCache(RadixCache):

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
             req_id
-        ]
+        )

         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -509,11 +508,10 @@ class HiRadixCache(RadixCache):
         self.cache_controller.mem_pool_host.update_prefetch(written_indices)

         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.mem_pool_host.free(
+        self.cache_controller.append_host_mem_release(
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
-        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

         return True
@@ -565,7 +563,11 @@ class HiRadixCache(RadixCache):
            len(new_input_tokens) % self.page_size
        )
        new_input_tokens = new_input_tokens[:prefetch_length]
-        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+        if (
+            not self.enable_storage
+            or prefetch_length < self.prefetch_threshold
+            or self.cache_controller.prefetch_rate_limited()
+        ):
            return

        last_host_node.protect_host()
@@ -573,6 +575,10 @@ class HiRadixCache(RadixCache):
        if host_indices is None:
            self.evict_host(prefetch_length)
            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            last_host_node.release_host()
+            # no sufficient host memory for prefetch
+            return
        operation = self.cache_controller.prefetch(
            req_id, host_indices, new_input_tokens, last_hash
        )
@@ -672,11 +678,11 @@ class HiRadixCache(RadixCache):
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

-    def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.monotonic()
+    def insert(self, key: List, value, chunked=False):
        if len(key) == 0:
            return 0

+        node = self.root_node
        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

@@ -693,7 +699,7 @@ class HiRadixCache(RadixCache):
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
                else:
-                    self.inc_hit_count(node)
+                    self._inc_hit_count(node, chunked)
                    total_prefix_length += prefix_len
            else:
                # partial match, split the node
@@ -703,7 +709,7 @@ class HiRadixCache(RadixCache):
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
-                    self.inc_hit_count(new_node)
+                    self._inc_hit_count(new_node, chunked)
                    total_prefix_length += prefix_len
                node = new_node

@@ -737,7 +743,7 @@ class HiRadixCache(RadixCache):
                last_hash = new_node.hash_value[-1]

        if self.cache_controller.write_policy != "write_back":
-            self.inc_hit_count(new_node)
+            self._inc_hit_count(new_node, chunked)
        return total_prefix_length

    def _collect_leaves_device(self):
@@ -764,3 +770,20 @@ class HiRadixCache(RadixCache):
                if not cur_child.evicted:
                    stack.append(cur_child)
        return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+            rid
+        )
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

sglang/srt/mem_cache/lora_radix_cache.py

@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
        """Cache request when it is unfinished."""
        if self.disable:
            return

sglang/srt/mem_cache/memory_pool.py

@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                layer_num,
                self.size // self.page_size + 1,
                self.page_size,
+                1,
                self.kv_lora_rank,
            ),
            dtype=self.store_dtype,
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                layer_num,
                self.size // self.page_size + 1,
                self.page_size,
+                1,
                self.qk_rope_head_dim,
            ),
            dtype=self.store_dtype,
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)

        if cache_v is None:
            cache_k, cache_v = cache_k.split(

sglang/srt/mem_cache/memory_pool_host.py

@@ -7,7 +7,6 @@ from functools import wraps
 import psutil
 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu

@@ -464,11 +463,11 @@ class MHATokenToKVPoolHost(HostKVCache):
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices):
-        local_rank = get_tensor_model_parallel_rank()
+    def get_buffer_meta(self, keys, indices, local_rank):
        ptr_list = []
        key_list = []
        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
        v_offset = (
            self.layer_num
            * self.size
@@ -704,10 +703,11 @@ class MLATokenToKVPoolHost(HostKVCache):
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
        ptr_list = []
        key_list = []
        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
        for index in range(0, len(indices), self.page_size):
            k_ptr = (
                kv_buffer_data_ptr
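
Both `get_buffer_meta` changes also convert the index tensor to a plain Python list before the per-page loop, presumably to avoid per-element tensor indexing in a tight Python loop. A small illustrative snippet with arbitrary sizes, not taken from the package:

```python
import torch

indices = torch.arange(0, 4096, dtype=torch.int64)
page_size = 64

# one bulk conversion up front ...
indices_list = indices.tolist()

# ... so the per-page loop works on Python ints, not 0-d tensors
first_index_per_page = [
    indices_list[i] for i in range(0, len(indices_list), page_size)
]
assert len(first_index_per_page) == 4096 // page_size
```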

sglang/srt/mem_cache/radix_cache.py

@@ -62,7 +62,6 @@ class TreeNode:
        self.host_value: Optional[torch.Tensor] = None
        # store hash values of each pages
        self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False

        self.id = TreeNode.counter if id is None else id
        TreeNode.counter += 1
@@ -195,7 +194,7 @@ class RadixCache(BasePrefixCache):
            last_host_node=last_node,
        )

-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
        if self.disable:
            return 0

@@ -240,7 +239,7 @@ class RadixCache(BasePrefixCache):
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
        """Cache request when it is unfinished."""
        if self.disable:
            return
@@ -261,7 +260,9 @@ class RadixCache(BasePrefixCache):
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )
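
The `chunked` flag added here only threads through to `insert()`; in `HiRadixCache._inc_hit_count` (earlier in this diff) it suppresses the hit-count bump. A hypothetical call site, since the scheduler-side change is not shown in this section:

```python
# Hypothetical caller; `cache` is a RadixCache/HiRadixCache and `req` an
# in-flight request. Passing chunked=True for chunked-prefill requests keeps
# repeated partial inserts from inflating hit counts (see _inc_hit_count above).
def cache_partial_prefill(cache, req, is_chunked_prefill: bool) -> None:
    cache.cache_unfinished_req(req, chunked=is_chunked_prefill)
```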

sglang/srt/mem_cache/radix_cache_cpp.py

@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
        self.dec_lock_ref(req.last_node)
        self.req_to_token_pool.free(req.req_pool_idx)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
        """Cache request when it is unfinished."""
        assert req.req_pool_idx is not None
        token_ids = req.fill_ids

sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py

@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple

+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request, status
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

@@ -24,10 +26,10 @@ class RankMetadata:
    """Holds all metadata for a single rank."""

    def __init__(self, num_pages: int):
-        self.lock = threading.RLock()
+        self.lock = threading.Lock()
        self.num_pages = num_pages
        self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index: Dict[str, int] = {}
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
        # Todo: Support multi files for HF3FS

    def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
            for i, (key, prefix_key) in enumerate(keys):
                if key in self.key_to_index:
                    results[i] = (True, self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
                else:
                    new_keys_to_process.append((i, key, prefix_key))

            # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
            for i, key, prefix_key in new_keys_to_process:
                if len(self.free_pages) > 0:
-                    page_idx = self.free_pages.pop()
-                    results[i] = (False, page_idx)
+                    page_index = self.free_pages.pop()
                else:
-                    results[i] = (False, -1)
+                    page_index = self.key_to_index.popitem(last=False)[1]
+
+                results[i] = (False, page_index)

            return results

@@ -68,6 +72,7 @@ class RankMetadata:
        with self.lock:
            for key, page_index in written_keys_to_confirm:
                self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)

            for page_index in pages_to_release:
                if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
    def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
        """Get page indices for keys."""
        with self.lock:
-            return [self.key_to_index.get(key) for key in keys]
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results


class GlobalMetadataState:
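
The switch from a plain dict to an `OrderedDict` gives `RankMetadata` LRU behavior: hits call `move_to_end()`, and when `free_pages` runs dry the oldest mapping is evicted with `popitem(last=False)`. A minimal standalone sketch of the same mechanism (our own toy class, not the one in the diff):

```python
from collections import OrderedDict
from typing import List, Optional


class LruPageTable:
    """Toy LRU page table: recently used keys move to the end,
    and eviction pops from the front when no free page remains."""

    def __init__(self, num_pages: int) -> None:
        self.free_pages: List[int] = list(range(num_pages))
        self.key_to_index: "OrderedDict[str, int]" = OrderedDict()

    def get(self, key: str) -> Optional[int]:
        index = self.key_to_index.get(key)
        if index is not None:
            self.key_to_index.move_to_end(key)  # refresh recency on hit
        return index

    def allocate(self, key: str) -> int:
        if self.free_pages:
            index = self.free_pages.pop()
        else:
            # evict the least recently used key and reuse its page slot
            _, index = self.key_to_index.popitem(last=False)
        self.key_to_index[key] = index
        return index


table = LruPageTable(num_pages=2)
table.allocate("a"), table.allocate("b")
table.get("a")               # "a" becomes most recently used
table.allocate("c")          # evicts "b", the least recently used
assert table.get("b") is None
```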
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:

    def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
        self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
        self._setup_routes()

    def _setup_routes(self):
@@ -199,17 +212,25 @@

    def get_rank_metadata(self, rank: int) -> RankMetadata:
        """Get rank metadata with proper error handling."""
-        with self.state.global_lock:
-            if rank not in self.state.ranks:
-                raise HTTPException(
-                    status_code=404,
-                    detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
-                )
-            return self.state.ranks[rank]
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)

    async def initialize(self, rank: int, request: Request):
        """Initialize a rank with specified number of pages."""
-        data = await request.json()
+        data = await self._read_json(request)
        num_pages = data["num_pages"]
        with self.state.global_lock:
            if rank in self.state.ranks:
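
The helpers added above parse the raw request body with `orjson` and wrap payloads in `ORJSONResponse`, skipping FastAPI's `jsonable_encoder`; the endpoints in the next hunk additionally switch acknowledgement-only replies to 204 No Content. A self-contained sketch of that request/response pattern; the `/echo` and `/ack` routes are illustrative and not part of the metadata server:

```python
import orjson
from fastapi import FastAPI, Request, Response
from fastapi.responses import ORJSONResponse

app = FastAPI(default_response_class=ORJSONResponse)


@app.post("/echo")
async def echo(request: Request):
    # bypass request.json() / the stdlib json module
    data = orjson.loads(await request.body())
    return ORJSONResponse({"keys": data.get("keys", [])})


@app.post("/ack")
async def ack(request: Request):
    orjson.loads(await request.body())  # parse the body, ignore the content
    return Response(status_code=204)    # nothing for the client to decode
```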
@@ -223,57 +244,55 @@
        else:
            logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
            self.state.ranks[rank] = RankMetadata(num_pages)
-        return {"message": f"Rank {rank} is ready."}
+        return Response(status_code=204)

    async def exists(self, rank: int, request: Request):
        """Check if keys exist in metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
        keys = data["keys"]
        metadata = self.get_rank_metadata(rank)
        results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})

    async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
        """Reserve and allocate page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

    async def confirm_write(self, rank: int, request: Request):
        """Confirm write operations and release pages."""
-        data = await request.json()
+        data = await self._read_json(request)
        metadata = self.get_rank_metadata(rank)
        success_written_keys = data.get("written_keys_to_confirm", [])
        released_pages = data.get("pages_to_release", [])

        metadata.confirm_write(success_written_keys, released_pages)

-        return {
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)

    async def delete_keys(self, rank: int, request: Request):
        """Delete keys from metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
        metadata = self.get_rank_metadata(rank)
        count = metadata.delete_keys(data["keys"])
-        return {"message": f"Rank {rank}: {count} keys deleted."}
+        return Response(status_code=204)

    async def clear(self, rank: int):
        """Clear all metadata for a rank."""
        metadata = self.get_rank_metadata(rank)
        metadata.clear_all()
-        return {"message": f"Rank {rank}: Metadata cleared."}
+        return Response(status_code=204)

    async def get_page_indices(self, rank: int, request: Request):
        """Get page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

    def run(self, host: str = "0.0.0.0", port: int = 18000):
        """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
-        adapter = HTTPAdapter(max_retries=retry_strategy)
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
        self._session.mount("http://", adapter)

    def _post(self, endpoint: str, json_data: dict) -> dict:
        try:
-            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
            response.raise_for_status()
-            return response.json()
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
        except requests.exceptions.RequestException as e:
            logging.error(f"Failed to POST to {endpoint} after retries: {e}")
            raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
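
On the client side the same ideas show up as a larger connection pool, orjson-encoded request bodies, and tolerance for the new empty 204 replies. A sketch of an equivalent helper, with illustrative retry settings rather than the client's exact configuration:

```python
import orjson
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


def make_session(pool_size: int = 256) -> requests.Session:
    session = requests.Session()
    retry = Retry(total=3, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504])
    adapter = HTTPAdapter(
        max_retries=retry, pool_connections=pool_size, pool_maxsize=pool_size
    )
    session.mount("http://", adapter)
    return session


def post_json(session: requests.Session, url: str, payload: dict) -> dict:
    response = session.post(
        url,
        data=orjson.dumps(payload),  # bytes body, already JSON-encoded
        headers={"Content-Type": "application/json"},
    )
    response.raise_for_status()
    if response.status_code == 204 or not response.content:
        return {}  # acknowledgement-only endpoints return no body
    return orjson.loads(response.content)
```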