sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (32)
  1. sglang/srt/configs/model_config.py +2 -1
  2. sglang/srt/distributed/parallel_state.py +3 -1
  3. sglang/srt/entrypoints/engine.py +1 -1
  4. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  5. sglang/srt/layers/moe/ep_moe/layer.py +2 -7
  6. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  7. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  8. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  9. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  10. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  11. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  12. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
  13. sglang/srt/layers/quantization/w4afp8.py +30 -25
  14. sglang/srt/managers/detokenizer_manager.py +0 -34
  15. sglang/srt/managers/multi_tokenizer_mixin.py +44 -6
  16. sglang/srt/managers/scheduler.py +3 -0
  17. sglang/srt/mem_cache/hiradix_cache.py +19 -3
  18. sglang/srt/mem_cache/memory_pool_host.py +2 -0
  19. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  20. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +27 -6
  21. sglang/srt/models/deepseek_v2.py +5 -0
  22. sglang/srt/models/gpt_oss.py +5 -4
  23. sglang/srt/models/longcat_flash.py +26 -15
  24. sglang/srt/models/longcat_flash_nextn.py +23 -15
  25. sglang/srt/utils.py +0 -10
  26. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  27. sglang/version.py +1 -1
  28. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +2 -2
  29. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +32 -29
  30. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  31. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  32. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
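The two added lines give each expert-parallel (EP) rank a contiguous, inclusive slice of the global expert range. A minimal sketch of the arithmetic, using made-up values for num_experts, moe_ep_size, and moe_ep_rank:

# Sketch: contiguous expert partitioning across EP ranks (hypothetical sizes).
num_experts = 16   # global expert count
moe_ep_size = 4    # expert-parallel world size
num_local_experts = num_experts // moe_ep_size  # 4 experts per rank

for moe_ep_rank in range(moe_ep_size):
    start_expert_id = moe_ep_rank * num_local_experts
    end_expert_id = start_expert_id + num_local_experts - 1  # inclusive bound
    print(f"rank {moe_ep_rank}: experts {start_expert_id}..{end_expert_id}")
# rank 0: experts 0..3, rank 1: experts 4..7, rank 2: experts 8..11, rank 3: experts 12..15
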
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):
 
         if (
             "compressed" in self.quant_method.__class__.__name__.lower()
-            and param.data[expert_id] != 1
-            and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+            or "w4afp8" in self.quant_config.get_name()
+            and (param.data[expert_id] != 1).any()
+            and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
         ):
             raise ValueError(
                 "input_scales of w1 and w3 of a layer "
sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import triton
+
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+        top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+        to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+        ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+        with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+        Tokens 12 are non-existent (padding) and are ignored in
+        the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+        by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
+    cumsum_buffer = torch.empty(
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
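For reference, the documented example can be reproduced with a short plain-PyTorch sketch. The wheel itself delegates the sort-and-pad work to the fused sgl_kernel.moe_align_block_size kernel, so this only illustrates the layout it produces, not the shipped code path:

# Plain-PyTorch illustration of the docstring example above (not the CUDA/HIP kernel).
import torch

topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
block_size, num_experts = 4, 4

flat = topk_ids.flatten()     # [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]
pad_value = flat.numel()      # 12 marks padding slots

sorted_ids = []
for e in range(num_experts + 1):  # the example's expert ids run 1..4
    slots = (flat == e).nonzero(as_tuple=True)[0].tolist()
    if not slots:
        continue
    pad = (-len(slots)) % block_size  # pad each expert up to a multiple of block_size
    sorted_ids.extend(slots + [pad_value] * pad)

print(sorted_ids)  # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
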
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
         kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
     )
 
+    old_compile_mode = deep_gemm.get_compile_mode()
+    deep_gemm.set_compile_mode(1)
     # TODO can use multi thread
     for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
         executor.execute(m=m)
+    deep_gemm.set_compile_mode(old_compile_mode)
+
+    # clean up input buffers
+    torch.cuda.current_stream().synchronize()
+    del executor
+    torch.cuda.empty_cache()
 
 
 class _BaseWarmupExecutor:
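The warmup loop above follows a save/set/restore pattern around deep_gemm.set_compile_mode. Purely as an illustration (this wrapper is not part of sglang), the same discipline can be expressed as a context manager:

# Illustrative sketch only: wrap the save/set/restore of the DeepGEMM compile mode.
from contextlib import contextmanager

@contextmanager
def compile_mode(deep_gemm, mode: int):
    old_mode = deep_gemm.get_compile_mode()
    deep_gemm.set_compile_mode(mode)
    try:
        yield
    finally:
        deep_gemm.set_compile_mode(old_mode)  # always restore, even on error

# usage sketch:
# with compile_mode(deep_gemm, 1):
#     for m in m_list:
#         executor.execute(m=m)
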
sglang/srt/layers/quantization/w4afp8.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer, EPMoE):
+        elif isinstance(layer, FusedMoE):
             return W4AFp8MoEMethod(self)
         return None
 
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
         return []
 
 
-class W4AFp8MoEMethod(FusedMoEMethodBase):
+def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+    s_shape = scales.shape
+    # Reshape to separate groups of 4
+    alignment = 4 if s_shape[2] % 4 == 0 else 1
+    scales_interleaved = scales.reshape(
+        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
+    )
+    # Permute dimensions to interleave
+    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+    # Reshape back to original dimensions but with interleaved values
+    scales_interleaved = scales_interleaved.reshape(
+        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
+    )
+    return scales_interleaved.contiguous()
+
 
+class W4AFp8MoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: W4AFp8Config):
         self.quant_config = quant_config
 
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         return
 
-    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
-        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
-        s_shape = scales.shape
-        # Reshape to separate groups of 4
-        scales_interleaved = scales.reshape(
-            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
-        )
-        # Permute dimensions to interleave
-        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
-        # Reshape back to original dimensions but with interleaved values
-        scales_interleaved = scales_interleaved.reshape(
-            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
-        )
-        return scales_interleaved.contiguous()
-
     def process_weights_after_loading(self, layer: Module) -> None:
         dtype = torch.bfloat16
         device = layer.w2_weight.device
 
         # Interleave w13_weight_scale (gate_up_proj)
         w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
-        w13_weight_scale = self._interleave_scales(w13_weight_scale)
+        w13_weight_scale = interleave_scales(w13_weight_scale)
         layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
 
         # Interleave w2_weight_scale (down_proj)
         w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
-        w2_weight_scale = self._interleave_scales(w2_weight_scale)
+        w2_weight_scale = interleave_scales(w2_weight_scale)
         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
 
         # Process input scales
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-        local_topk_ids = torch.where(
-            topk_ids == -1,
-            layer.num_experts,
-            topk_ids,
-        )
+        if get_moe_expert_parallel_world_size() > 1:
+            local_topk_ids = torch.where(
+                topk_ids == -1,
+                layer.num_experts,
+                topk_ids,
+            )
 
         output = cutlass_w4a8_moe(
             layer.start_expert_id,
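interleave_scales, now a module-level helper with an alignment fallback, only reorders the scale tensor's layout. A shape-level sketch on a toy tensor (sizes are arbitrary) shows the effect:

# Toy shape check: an [E, N, G] scale tensor becomes [E, G // 4, N * 4] when G % 4 == 0;
# with the new fallback (alignment = 1) the result is just a transpose to [E, G, N].
import torch

E, N, G = 2, 3, 8                              # experts, rows, scale groups
scales = torch.arange(E * N * G).reshape(E, N, G)

alignment = 4 if G % 4 == 0 else 1
out = (
    scales.reshape(E, N, G // alignment, alignment)
    .permute(0, 2, 1, 3)
    .reshape(E, G // alignment, N * alignment)
    .contiguous()
)
print(out.shape)  # torch.Size([2, 2, 12])
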
sglang/srt/managers/detokenizer_manager.py
@@ -39,7 +39,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     freeze_gc,
-    get_worker_ids_from_req_rids,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -120,39 +119,6 @@ class DetokenizerManager(MultiTokenizerMixin):
             if output is not None:
                 self.send_to_tokenizer.send_pyobj(output)
 
-    def multi_tokenizer_manager_event_loop(self):
-        """The event loop that handles requests, for multi tokenizer manager mode only"""
-        self.create_sockets_mapping()
-        while True:
-            recv_obj = self.recv_from_scheduler.recv_pyobj()
-            output = self._request_dispatcher(recv_obj)
-            if output is None:
-                continue
-            # Extract worker_id from rid
-            if isinstance(recv_obj.rids, list):
-                worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
-            else:
-                raise RuntimeError(
-                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
-                )
-
-            # Send data using the corresponding socket
-            for i, worker_id in enumerate(worker_ids):
-                if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    if self.register_tokenizer_ipc(recv_obj, worker_id):
-                        logger.info(
-                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
-                        )
-                    continue
-                else:
-                    if worker_id not in self.tokenizer_mapping:
-                        logger.error(
-                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                        )
-                        continue
-                    new_output = self._handle_output_by_index(output, i)
-                    self.tokenizer_mapping[worker_id].send_pyobj(new_output)
-
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
     ):
sglang/srt/managers/multi_tokenizer_mixin.py
@@ -37,11 +37,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_worker_ids_from_req_rids,
-    get_zmq_socket,
-    kill_process_tree,
-)
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -344,6 +340,48 @@ class MultiTokenizerMixin:
             new_output = output
         return new_output
 
+    def get_worker_ids_from_req_rids(self, rids):
+        if isinstance(rids, list):
+            worker_ids = [int(rid.split("_")[0]) for rid in rids]
+        elif isinstance(rids, str):
+            worker_ids = [int(rids.split("_")[0])]
+        else:
+            worker_ids = []
+        return worker_ids
+
+    def multi_tokenizer_manager_event_loop(self):
+        """The event loop that handles requests, for multi tokenizer manager mode only"""
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = self.recv_from_scheduler.recv_pyobj()
+            output = self._request_dispatcher(recv_obj)
+            if output is None:
+                continue
+            # Extract worker_id from rid
+            if isinstance(recv_obj.rids, list):
+                worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+            else:
+                raise RuntimeError(
+                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
+                )
+
+            # Send data using the corresponding socket
+            for i, worker_id in enumerate(worker_ids):
+                if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                    if self.register_tokenizer_ipc(recv_obj, worker_id):
+                        logger.info(
+                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
+                        )
+                    continue
+                else:
+                    if worker_id not in self.tokenizer_mapping:
+                        logger.error(
+                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                        )
+                        continue
+                    new_output = self._handle_output_by_index(output, i)
+                    self.tokenizer_mapping[worker_id].send_pyobj(new_output)
+
     def clear_tokenizer_mapping(self):
         if hasattr(self, "tokenizer_mapping"):
             for socket in self.tokenizer_mapping.values():
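The helper moved onto MultiTokenizerMixin relies on the convention that a request id starts with the tokenizer worker id followed by an underscore. A standalone sketch with made-up request ids:

# Sketch of the rid convention: "<worker_id>_<request_hash>". The rids below are made up.
def worker_ids_from_rids(rids):
    if isinstance(rids, list):
        return [int(rid.split("_")[0]) for rid in rids]
    if isinstance(rids, str):
        return [int(rids.split("_")[0])]
    return []

print(worker_ids_from_rids(["3_9f2c41d0", "0_77ab02ce"]))  # [3, 0]
print(worker_ids_from_rids("1_5d3e9a2b"))                  # [1]
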
@@ -406,7 +444,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
             worker_ids = [recv_obj.worker_id]
             recv_obj = recv_obj.obj
         else:
-            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
 
         if len(worker_ids) == 0:
             logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
sglang/srt/managers/scheduler.py
@@ -2403,6 +2403,9 @@ class Scheduler(
                 # This only works for requests that have not started anything.
                 # We still need to send something back to TokenizerManager to clean up the state.
                 req = self.waiting_queue.pop(i)
+                if self.enable_hicache_storage:
+                    # to release prefetch events associated with the request
+                    self.tree_cache.release_aborted_request(req.rid)
                 self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
sglang/srt/mem_cache/hiradix_cache.py
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
 
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
             req_id
-        ]
+        )
 
         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
-        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
 
         return True
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
                 if not cur_child.evicted:
                     stack.append(cur_child)
         return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+            rid
+        )
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
sglang/srt/mem_cache/memory_pool_host.py
@@ -467,6 +467,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
@@ -706,6 +707,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple
 
+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request, status
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
@@ -24,10 +26,10 @@ class RankMetadata:
     """Holds all metadata for a single rank."""
 
     def __init__(self, num_pages: int):
-        self.lock = threading.RLock()
+        self.lock = threading.Lock()
         self.num_pages = num_pages
         self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index: Dict[str, int] = {}
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
         # Todo: Support multi files for HF3FS
 
     def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
             for i, (key, prefix_key) in enumerate(keys):
                 if key in self.key_to_index:
                     results[i] = (True, self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
                 else:
                     new_keys_to_process.append((i, key, prefix_key))
 
             # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
             for i, key, prefix_key in new_keys_to_process:
                 if len(self.free_pages) > 0:
-                    page_idx = self.free_pages.pop()
-                    results[i] = (False, page_idx)
+                    page_index = self.free_pages.pop()
                 else:
-                    results[i] = (False, -1)
+                    page_index = self.key_to_index.popitem(last=False)[1]
+
+                results[i] = (False, page_index)
 
             return results
 
@@ -68,6 +72,7 @@ class RankMetadata:
         with self.lock:
             for key, page_index in written_keys_to_confirm:
                 self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)
 
             for page_index in pages_to_release:
                 if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
     def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
         """Get page indices for keys."""
         with self.lock:
-            return [self.key_to_index.get(key) for key in keys]
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results
 
 
 class GlobalMetadataState:
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
 
     def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
         self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
         self._setup_routes()
 
     def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
 
     def get_rank_metadata(self, rank: int) -> RankMetadata:
         """Get rank metadata with proper error handling."""
-        with self.state.global_lock:
-            if rank not in self.state.ranks:
-                raise HTTPException(
-                    status_code=404,
-                    detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
-                )
-            return self.state.ranks[rank]
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)
 
     async def initialize(self, rank: int, request: Request):
         """Initialize a rank with specified number of pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         num_pages = data["num_pages"]
         with self.state.global_lock:
             if rank in self.state.ranks:
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
             else:
                 logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                 self.state.ranks[rank] = RankMetadata(num_pages)
-        return {"message": f"Rank {rank} is ready."}
+        return Response(status_code=204)
 
     async def exists(self, rank: int, request: Request):
         """Check if keys exist in metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
         keys = data["keys"]
         metadata = self.get_rank_metadata(rank)
         results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})
 
     async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
         """Reserve and allocate page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})
 
     async def confirm_write(self, rank: int, request: Request):
         """Confirm write operations and release pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         success_written_keys = data.get("written_keys_to_confirm", [])
         released_pages = data.get("pages_to_release", [])
 
         metadata.confirm_write(success_written_keys, released_pages)
 
-        return {
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)
 
     async def delete_keys(self, rank: int, request: Request):
         """Delete keys from metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         count = metadata.delete_keys(data["keys"])
-        return {"message": f"Rank {rank}: {count} keys deleted."}
+        return Response(status_code=204)
 
     async def clear(self, rank: int):
         """Clear all metadata for a rank."""
        metadata = self.get_rank_metadata(rank)
        metadata.clear_all()
-        return {"message": f"Rank {rank}: Metadata cleared."}
+        return Response(status_code=204)
 
     async def get_page_indices(self, rank: int, request: Request):
         """Get page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})
 
     def run(self, host: str = "0.0.0.0", port: int = 18000):
         """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
             status_forcelist=[500, 502, 503, 504],
             allowed_methods=["GET", "POST"],
         )
-        adapter = HTTPAdapter(max_retries=retry_strategy)
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
         self._session.mount("http://", adapter)
 
     def _post(self, endpoint: str, json_data: dict) -> dict:
         try:
-            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
             response.raise_for_status()
-            return response.json()
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
         except requests.exceptions.RequestException as e:
             logging.error(f"Failed to POST to {endpoint} after retries: {e}")
             raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
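The client-side change mirrors the server: request bodies are serialized with orjson.dumps, posted as raw bytes with an explicit Content-Type header, and 204/empty responses are treated as {}. A self-contained sketch of that round trip (the URL and payload are made up for illustration):

# Sketch of the orjson round trip; endpoint and payload are illustrative only.
import orjson
import requests

def post_json(session: requests.Session, url: str, payload: dict) -> dict:
    response = session.post(
        url,
        data=orjson.dumps(payload),                      # raw bytes, not json=
        headers={"Content-Type": "application/json"},
    )
    response.raise_for_status()
    if response.status_code == 204 or not response.content:
        return {}
    return orjson.loads(response.content)

# session = requests.Session()
# post_json(session, "http://localhost:18000/0/exists", {"keys": ["k0", "k1"]})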