sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/memory_pool_host.py

@@ -358,6 +358,7 @@ class MHATokenToKVPoolHost(HostKVCache):
                 dst_v=device_pool.v_buffer[layer_id],
                 src_indices=host_indices,
                 dst_indices=device_indices,
+                layer_id=layer_id,
                 item_size=self.token_stride_size,
                 src_layout_dim=self.layout_dim,
             )
@@ -471,27 +472,26 @@ class MHATokenToKVPoolHost(HostKVCache):
             * self.dtype.itemsize
         )
         for index in range(0, len(indices), self.page_size):
-            for layer_id in range(self.layer_num):
-                k_ptr = (
-                    kv_buffer_data_ptr
-                    + indices[index]
-                    * self.head_num
-                    * self.head_dim
-                    * self.dtype.itemsize
-                    + layer_id
-                    * self.size
-                    * self.head_num
-                    * self.head_dim
-                    * self.dtype.itemsize
-                )
-                v_ptr = k_ptr + v_offset
-                ptr_list.append(k_ptr)
-                ptr_list.append(v_ptr)
-                key_ = keys[index // self.page_size]
-                key_list.append(f"{key_}_{layer_id}_k")
-                key_list.append(f"{key_}_{layer_id}_v")
+            k_ptr = (
+                kv_buffer_data_ptr
+                + indices[index]
+                * self.layer_num
+                * self.head_num
+                * self.head_dim
+                * self.dtype.itemsize
+            )
+            v_ptr = k_ptr + v_offset
+            ptr_list.append(k_ptr)
+            ptr_list.append(v_ptr)
+            key_ = keys[index // self.page_size]
+            key_list.append(f"{key_}_k")
+            key_list.append(f"{key_}_v")
         element_size = (
-            self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            self.layer_num
+            * self.dtype.itemsize
+            * self.page_size
+            * self.head_num
+            * self.head_dim
         )
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
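
The hunk above changes how MHATokenToKVPoolHost enumerates host pages for storage I/O: rc0 used a layer-major layout, emitting one key and pointer per page per layer, while rc2 assumes a page-major layout in which all layers of a page are contiguous, so one key and pointer cover the whole page. A minimal sketch of the two offset formulas, with hypothetical dimensions (the real values come from the pool config), to make the stride change concrete:

    # Hypothetical pool dimensions, for illustration only.
    layer_num, size, head_num, head_dim, itemsize = 32, 4096, 8, 128, 2

    def k_offset_rc0(token_index: int, layer_id: int) -> int:
        # Layer-major: all tokens of layer 0, then all tokens of layer 1, ...
        return (
            token_index * head_num * head_dim * itemsize
            + layer_id * size * head_num * head_dim * itemsize
        )

    def k_offset_rc2(token_index: int) -> int:
        # Page-major: every layer of a token is contiguous, one pointer per page.
        return token_index * layer_num * head_num * head_dim * itemsize

Correspondingly, the per-page element size grows by a factor of layer_num and the storage keys drop the per-layer suffix; the MLA hunks below make the same change with (kv_lora_rank + qk_rope_head_dim) in place of head_num * head_dim.
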
@@ -585,6 +585,7 @@ class MLATokenToKVPoolHost(HostKVCache):
                 dst=device_pool.kv_buffer[layer_id],
                 src_indices=host_indices,
                 dst_indices=device_indices,
+                layer_id=layer_id,
                 item_size=self.token_stride_size,
                 src_layout_dim=self.layout_dim,
             )
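
As with the MHA hunk earlier, the per-layer device transfer now receives layer_id explicitly. A plausible reading, consistent with the page-major host layout in the surrounding hunks, is that a layer's host-side slice can no longer be found from the buffer base alone; the copy needs layer_id to select the slice inside each token's block. A hedged sketch of that arithmetic (hypothetical helper, not the actual kernel signature):

    def host_offset(token_index: int, layer_id: int, layer_num: int, item_size: int) -> int:
        # Page-major host layout: [token][layer][entry]; layer_id picks the
        # intra-token slice. Hypothetical helper, for illustration only.
        return (token_index * layer_num + layer_id) * item_size
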
@@ -685,22 +686,19 @@ class MLATokenToKVPoolHost(HostKVCache):
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         for index in range(0, len(indices), self.page_size):
-            for layer_id in range(self.layer_num):
-                k_ptr = (
-                    kv_buffer_data_ptr
-                    + indices[index]
-                    * (self.kv_lora_rank + self.qk_rope_head_dim)
-                    * self.dtype.itemsize
-                    + layer_id
-                    * self.size
-                    * (self.kv_lora_rank + self.qk_rope_head_dim)
-                    * self.dtype.itemsize
-                )
-                ptr_list.append(k_ptr)
-                key_ = keys[index // self.page_size]
-                key_list.append(f"{key_}_{layer_id}_k")
+            k_ptr = (
+                kv_buffer_data_ptr
+                + indices[index]
+                * self.layer_num
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+                * self.dtype.itemsize
+            )
+            ptr_list.append(k_ptr)
+            key_ = keys[index // self.page_size]
+            key_list.append(f"{key_}_k")
         element_size = (
-            self.dtype.itemsize
+            self.layer_num
+            * self.dtype.itemsize
             * self.page_size
             * (self.kv_lora_rank + self.qk_rope_head_dim)
         )
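
One way to sanity-check the MLA bookkeeping above: the number of keys per page shrinks by a factor of layer_num while each element grows by the same factor, so the total bytes flushed per page are unchanged. A quick check under hypothetical dimensions (illustrative only; real values come from the model config):

    # Hypothetical MLA pool dimensions, for illustration only.
    layer_num, page_size, kv_lora_rank, qk_rope_head_dim, itemsize = 32, 64, 512, 64, 2
    entry_dim = kv_lora_rank + qk_rope_head_dim

    # rc0: one key per (page, layer), each element covering a single layer.
    rc0_keys = layer_num
    rc0_element_size = itemsize * page_size * entry_dim

    # rc2: one key per page, covering all layers at once.
    rc2_keys = 1
    rc2_element_size = layer_num * itemsize * page_size * entry_dim

    assert rc0_keys * rc0_element_size == rc2_keys * rc2_element_size
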
sglang/srt/mem_cache/radix_cache.py

@@ -62,6 +62,7 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
+        self.backuped_storage = False

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -74,10 +75,6 @@ class TreeNode:
     def backuped(self):
         return self.host_value is not None

-    @property
-    def backuped_storage(self):
-        return self.hash_value is not None and len(self.hash_value) > 0
-
     def protect_host(self):
         """Protect the host value from eviction."""
         self.host_ref_counter += 1
@@ -498,7 +495,7 @@ class RadixCache(BasePrefixCache):
         # One BlockStored per ``page_size`` chunk.
         if self.enable_kv_cache_events:
             # First chunk links to the last page of the parent node (if any).
-            if node.parent is None:
+            if node.parent is None or node != self.root_node:
                 parent_block_hash = None
             else:
                 last_page_start = (

sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py (new file)

@@ -0,0 +1,443 @@
+import argparse
+import atexit
+import json
+import logging
+import threading
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import requests
+from fastapi import FastAPI, HTTPException, Request, status
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
+from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import Hf3fsMetadataInterface
+
+# --- Configuration ---
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+
+# --- Data Models ---
+class RankMetadata:
+    """Holds all metadata for a single rank."""
+
+    def __init__(self, num_pages: int):
+        self.lock = threading.RLock()
+        self.num_pages = num_pages
+        self.free_pages: List[int] = list(range(num_pages))
+        self.key_to_index: Dict[str, int] = {}
+        # Todo: Support multi files for HF3FS
+
+    def exists_keys(self, keys: List[str]) -> List[bool]:
+        """Check if keys exist in metadata."""
+        with self.lock:
+            return [key in self.key_to_index for key in keys]
+
+    def reserve_and_allocate_page_indices(
+        self, keys: List[Tuple[str, str]]
+    ) -> List[Tuple[bool, int]]:
+        """Reserve and allocate page indices for keys."""
+        with self.lock:
+            results = [None] * len(keys)
+            new_keys_to_process = []
+
+            for i, (key, prefix_key) in enumerate(keys):
+                if key in self.key_to_index:
+                    results[i] = (True, self.key_to_index[key])
+                else:
+                    new_keys_to_process.append((i, key, prefix_key))
+
+            # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
+            for i, key, prefix_key in new_keys_to_process:
+                if len(self.free_pages) > 0:
+                    page_idx = self.free_pages.pop()
+                    results[i] = (False, page_idx)
+                else:
+                    results[i] = (False, -1)
+
+            return results
+
+    def confirm_write(
+        self,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """Confirm write operations and release pages."""
+        with self.lock:
+            for key, page_index in written_keys_to_confirm:
+                self.key_to_index[key] = page_index
+
+            for page_index in pages_to_release:
+                if page_index not in self.free_pages:
+                    self.free_pages.append(page_index)
+
+    def delete_keys(self, keys: List[str]) -> int:
+        """Delete keys and return count of deleted keys."""
+        with self.lock:
+            count = 0
+            for key in keys:
+                if key in self.key_to_index:
+                    page_index = self.key_to_index.pop(key)
+                    if page_index not in self.free_pages:
+                        self.free_pages.append(page_index)
+                    count += 1
+            return count
+
+    def clear_all(self) -> None:
+        """Clear all metadata."""
+        with self.lock:
+            self.free_pages = list(range(self.num_pages))
+            self.key_to_index.clear()
+
+    def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
+        """Get page indices for keys."""
+        with self.lock:
+            return [self.key_to_index.get(key) for key in keys]
+
+
+class GlobalMetadataState:
+    """Manages the state for all ranks and persistence."""
+
+    def __init__(self, persistence_path: Optional[str], save_interval: int):
+        self.global_lock = threading.RLock()
+        self.ranks: Dict[int, RankMetadata] = {}
+        self.persistence_path = Path(persistence_path) if persistence_path else None
+        self.save_interval = save_interval
+        self.save_timer: Optional[threading.Timer] = None
+        self.is_shutting_down = False
+
+    def load_from_disk(self):
+        if not self.persistence_path or not self.persistence_path.exists():
+            logging.info("Persistence file not found. Starting with a clean state.")
+            return
+
+        logging.info(f"Loading state from {self.persistence_path}")
+        try:
+            with open(self.persistence_path, "r") as f:
+                persisted_data = json.load(f)
+
+            with self.global_lock:
+                for rank_id_str, data in persisted_data.items():
+                    rank_id = int(rank_id_str)
+                    num_pages = data["num_pages"]
+                    rank_meta = RankMetadata(num_pages)
+                    rank_meta.free_pages = data["free_pages"]
+                    rank_meta.key_to_index = dict(data["key_to_index"])
+                    self.ranks[rank_id] = rank_meta
+                logging.info(
+                    f"Successfully loaded metadata for {len(self.ranks)} ranks."
+                )
+        except (json.JSONDecodeError, KeyError, TypeError) as e:
+            logging.error(
+                f"Failed to load or parse persistence file: {e}. Starting fresh.",
+                exc_info=True,
+            )
+            self.ranks.clear()
+
+    def save_to_disk(self):
+        if not self.persistence_path:
+            return
+
+        logging.info("Persisting metadata to disk...")
+        with self.global_lock:
+            serializable_state = {}
+            for rank_id, rank_meta in self.ranks.items():
+                with rank_meta.lock:
+                    serializable_state[rank_id] = {
+                        "num_pages": rank_meta.num_pages,
+                        "free_pages": rank_meta.free_pages,
+                        "key_to_index": list(rank_meta.key_to_index.items()),
+                    }
+
+        try:
+            temp_path = self.persistence_path.with_suffix(".tmp")
+            with open(temp_path, "w") as f:
+                json.dump(serializable_state, f, indent=4)
+            temp_path.rename(self.persistence_path)
+            logging.info(f"Metadata successfully persisted to {self.persistence_path}")
+        except Exception as e:
+            logging.error(f"Failed to save metadata to disk: {e}", exc_info=True)
+
+    def schedule_save(self):
+        if self.is_shutting_down or not self.persistence_path:
+            return
+        self.save_to_disk()
+        self.save_timer = threading.Timer(self.save_interval, self.schedule_save)
+        self.save_timer.start()
+
+    def shutdown(self):
+        logging.info("Shutting down metadata server...")
+        self.is_shutting_down = True
+        if self.save_timer:
+            self.save_timer.cancel()
+        self.save_to_disk()
+        logging.info("Shutdown complete.")
+
+
+# --- Global MetadataServer implementation ---
+class Hf3fsMetadataServer:
+    """HF3FS Metadata Server that manages metadata for multiple ranks."""
+
+    def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
+        self.state = GlobalMetadataState(persistence_path, save_interval)
+        self.app = FastAPI()
+        self._setup_routes()
+
+    def _setup_routes(self):
+        """Setup FastAPI routes."""
+        self.app.post("/{rank}/initialize")(self.initialize)
+        self.app.post("/{rank}/exists")(self.exists)
+        self.app.post("/{rank}/reserve_and_allocate_page_indices")(
+            self.reserve_and_allocate_page_indices
+        )
+        self.app.post("/{rank}/confirm_write")(self.confirm_write)
+        self.app.post("/{rank}/delete_keys")(self.delete_keys)
+        self.app.post("/{rank}/clear")(self.clear)
+        self.app.post("/{rank}/get_page_indices")(self.get_page_indices)
+
+    def get_rank_metadata(self, rank: int) -> RankMetadata:
+        """Get rank metadata with proper error handling."""
+        with self.state.global_lock:
+            if rank not in self.state.ranks:
+                raise HTTPException(
+                    status_code=404,
+                    detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
+                )
+            return self.state.ranks[rank]
+
+    async def initialize(self, rank: int, request: Request):
+        """Initialize a rank with specified number of pages."""
+        data = await request.json()
+        num_pages = data["num_pages"]
+        with self.state.global_lock:
+            if rank in self.state.ranks:
+                logging.info(
+                    f"Rank {rank} already exists. Initialization request ignored."
+                )
+                if self.state.ranks[rank].num_pages != num_pages:
+                    logging.warning(
+                        f"Rank {rank} initialized with different num_pages. Existing: {self.state.ranks[rank].num_pages}, New: {num_pages}"
+                    )
+            else:
+                logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
+                self.state.ranks[rank] = RankMetadata(num_pages)
+        return {"message": f"Rank {rank} is ready."}
+
+    async def exists(self, rank: int, request: Request):
+        """Check if keys exist in metadata."""
+        data = await request.json()
+        keys = data["keys"]
+        metadata = self.get_rank_metadata(rank)
+        results = metadata.exists_keys(keys)
+        return {"exists": results}
+
+    async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
+        """Reserve and allocate page indices for keys."""
+        data = await request.json()
+        metadata = self.get_rank_metadata(rank)
+        keys = data["keys"]
+        results = metadata.reserve_and_allocate_page_indices(keys)
+        return {"indices": results}
+
+    async def confirm_write(self, rank: int, request: Request):
+        """Confirm write operations and release pages."""
+        data = await request.json()
+        metadata = self.get_rank_metadata(rank)
+        success_written_keys = data.get("written_keys_to_confirm", [])
+        released_pages = data.get("pages_to_release", [])
+
+        metadata.confirm_write(success_written_keys, released_pages)
+
+        return {
+            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
+        }
+
+    async def delete_keys(self, rank: int, request: Request):
+        """Delete keys from metadata."""
+        data = await request.json()
+        metadata = self.get_rank_metadata(rank)
+        count = metadata.delete_keys(data["keys"])
+        return {"message": f"Rank {rank}: {count} keys deleted."}
+
+    async def clear(self, rank: int):
+        """Clear all metadata for a rank."""
+        metadata = self.get_rank_metadata(rank)
+        metadata.clear_all()
+        return {"message": f"Rank {rank}: Metadata cleared."}
+
+    async def get_page_indices(self, rank: int, request: Request):
+        """Get page indices for keys."""
+        data = await request.json()
+        metadata = self.get_rank_metadata(rank)
+        keys = data["keys"]
+        results = metadata.get_page_indices(keys)
+        return {"indices": results}
+
+    def run(self, host: str = "0.0.0.0", port: int = 18000):
+        """Run the metadata server."""
+        self.state.load_from_disk()
+        if self.state.persistence_path:
+            self.state.schedule_save()
+        atexit.register(self.state.shutdown)
+
+        import uvicorn
+
+        logging.info(f"Starting metadata server on http://{host}:{port}")
+        if self.state.persistence_path:
+            logging.info(
+                f"Persistence is ENABLED. Saving to '{self.state.persistence_path}' every {self.state.save_interval} seconds."
+            )
+        else:
+            logging.info("Persistence is DISABLED.")
+
+        uvicorn.run(self.app, host=host, port=port)
+
+
+# --- Client implementation ---
+class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
+    """Global http metadata client for HF3FS."""
+
+    def __init__(self, base_url: str, max_retries: int = 3):
+        self.base_url = base_url.rstrip("/")
+        self._session = requests.Session()
+
+        retry_strategy = Retry(
+            total=max_retries,
+            backoff_factor=0.3,
+            status_forcelist=[500, 502, 503, 504],
+            allowed_methods=["GET", "POST"],
+        )
+        adapter = HTTPAdapter(max_retries=retry_strategy)
+        self._session.mount("http://", adapter)
+
+    def _post(self, endpoint: str, json_data: dict) -> dict:
+        try:
+            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
+            response.raise_for_status()
+            return response.json()
+        except requests.exceptions.RequestException as e:
+            logging.error(f"Failed to POST to {endpoint} after retries: {e}")
+            raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
+
+    def initialize(self, rank: int, num_pages: int) -> None:
+        self._post(f"{rank}/initialize", {"num_pages": num_pages})
+
+    def reserve_and_allocate_page_indices(
+        self, rank: int, keys: List[Tuple[str, str]]
+    ) -> List[Tuple[bool, int]]:
+        response = self._post(
+            f"{rank}/reserve_and_allocate_page_indices", {"keys": keys}
+        )
+        return [tuple(item) for item in response.get("indices")]
+
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        self._post(
+            f"{rank}/confirm_write",
+            {
+                "written_keys_to_confirm": written_keys_to_confirm,
+                "pages_to_release": pages_to_release,
+            },
+        )
+
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        self._post(f"{rank}/delete_keys", {"keys": keys})
+
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        response = self._post(f"{rank}/exists", {"keys": keys})
+        return response.get("exists", [])
+
+    def clear(self, rank: int) -> None:
+        self._post(f"{rank}/clear", {})
+
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        response = self._post(f"{rank}/get_page_indices", {"keys": keys})
+        return response.get("indices")
+
+
+class Hf3fsLocalMetadataClient(Hf3fsMetadataInterface):
+    """Local metadata client that directly operates on single RankMetadata in memory without metadata server."""
+
+    def __init__(self):
+        self.rank_metadata = None
+
+    def initialize(self, rank: int, num_pages: int) -> None:
+        self.rank_metadata = RankMetadata(num_pages)
+
+    def reserve_and_allocate_page_indices(
+        self, rank: int, keys: List[Tuple[str, str]]
+    ) -> List[Tuple[bool, int]]:
+        """Reserve and allocate page indices for keys."""
+        return self.rank_metadata.reserve_and_allocate_page_indices(keys)
+
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """Confirm write operations."""
+        self.rank_metadata.confirm_write(written_keys_to_confirm, pages_to_release)
+
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete keys."""
+        self.rank_metadata.delete_keys(keys)
+
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check if keys exist."""
+        return self.rank_metadata.exists_keys(keys)
+
+    def clear(self, rank: int) -> None:
+        """Clear all metadata for rank."""
+        self.rank_metadata.clear_all()
+
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """Get page indices for keys."""
+        return self.rank_metadata.get_page_indices(keys)
+
+
+def run_metadata_server(
+    host: str = "0.0.0.0",
+    port: int = 18000,
+    persistence_path: Optional[str] = None,
+    save_interval: int = 60,
+):
+    """Run the HF3FS metadata server."""
+    global server
+    server = Hf3fsMetadataServer(
+        persistence_path=persistence_path, save_interval=save_interval
+    )
+
+    server.run(host=host, port=port)
+
+
+# --- Main Execution ---
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="HF3FS Metadata Server")
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Host to bind the server to."
+    )
+    parser.add_argument(
+        "--port", type=int, default=18000, help="Port to run the server on."
+    )
+    parser.add_argument(
+        "--persistence-path",
+        type=str,
+        default=None,
+        help="Path to the file for persisting metadata. If not provided, persistence is disabled.",
+    )
+    parser.add_argument(
+        "--save-interval",
+        type=int,
+        default=60,
+        help="Interval in seconds for periodically saving metadata to disk.",
+    )
+    args = parser.parse_args()
+
+    run_metadata_server(args.host, args.port, args.persistence_path, args.save_interval)
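
For orientation, here is a minimal usage sketch of the new metadata server and its HTTP client, assuming the module path from the file list above and a server already running locally; the key strings and page counts are made up for illustration:

    # Start the server (assumed invocation):
    #   python -m sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server \
    #       --port 18000 --persistence-path /tmp/hf3fs_meta.json
    from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
        Hf3fsGlobalMetadataClient,
    )

    client = Hf3fsGlobalMetadataClient("http://localhost:18000")
    client.initialize(rank=0, num_pages=4)

    # Each (key, prefix_key) pair yields (already_exists, page_index);
    # fresh keys get pages popped from the free list, or -1 when full.
    indices = client.reserve_and_allocate_page_indices(0, [("k0", ""), ("k1", "k0")])

    # After writing page data to HF3FS, confirm the mapping; pages that
    # were reserved but not written go back via pages_to_release.
    client.confirm_write(
        0,
        written_keys_to_confirm=[("k0", indices[0][1])],
        pages_to_release=[indices[1][1]],
    )
    print(client.exists(0, ["k0", "k1"]))  # -> [True, False]

Hf3fsLocalMetadataClient exposes the same interface without the HTTP hop, backing a single rank's RankMetadata in process.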