sglang 0.5.2rc1__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/disaggregation/mini_lb.py +2 -2
  4. sglang/srt/distributed/parallel_state.py +43 -40
  5. sglang/srt/entrypoints/http_server.py +5 -1
  6. sglang/srt/entrypoints/openai/protocol.py +3 -3
  7. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  8. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  9. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  10. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  11. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  12. sglang/srt/layers/attention/aiter_backend.py +93 -68
  13. sglang/srt/layers/communicator.py +45 -7
  14. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/moe/utils.py +0 -1
  16. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  17. sglang/srt/layers/quantization/mxfp4.py +4 -1
  18. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  19. sglang/srt/layers/quantization/quark/utils.py +97 -0
  20. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  21. sglang/srt/layers/rocm_linear_utils.py +44 -0
  22. sglang/srt/layers/rotary_embedding.py +0 -18
  23. sglang/srt/managers/cache_controller.py +42 -39
  24. sglang/srt/managers/multi_tokenizer_mixin.py +4 -0
  25. sglang/srt/managers/schedule_policy.py +3 -2
  26. sglang/srt/managers/scheduler.py +4 -100
  27. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  28. sglang/srt/managers/template_manager.py +3 -3
  29. sglang/srt/managers/tokenizer_manager.py +1 -0
  30. sglang/srt/mem_cache/allocator.py +1 -1
  31. sglang/srt/mem_cache/hicache_storage.py +15 -10
  32. sglang/srt/mem_cache/hiradix_cache.py +5 -5
  33. sglang/srt/mem_cache/memory_pool_host.py +16 -11
  34. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +10 -2
  35. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  36. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  37. sglang/srt/metrics/collector.py +12 -4
  38. sglang/srt/metrics/utils.py +48 -0
  39. sglang/srt/model_executor/forward_batch_info.py +16 -17
  40. sglang/srt/model_executor/model_runner.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +240 -36
  42. sglang/srt/models/glm4_moe.py +10 -1
  43. sglang/srt/models/internvl.py +28 -0
  44. sglang/srt/models/minicpmv.py +165 -3
  45. sglang/srt/models/qwen2_moe.py +4 -1
  46. sglang/srt/models/qwen3.py +8 -2
  47. sglang/srt/models/qwen3_moe.py +39 -8
  48. sglang/srt/models/torch_native_llama.py +1 -1
  49. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  50. sglang/srt/server_args.py +79 -2
  51. sglang/srt/speculative/eagle_worker.py +158 -112
  52. sglang/srt/utils.py +12 -0
  53. sglang/test/few_shot_gsm8k.py +1 -0
  54. sglang/utils.py +1 -0
  55. sglang/version.py +1 -1
  56. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +1 -1
  57. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +65 -61
  58. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  59. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  60. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  61. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  62. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  63. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  64. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  65. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  66. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0

sglang/srt/managers/multi_tokenizer_mixin.py
@@ -23,6 +23,7 @@ import threading
 from multiprocessing import shared_memory
 from typing import Dict

+import setproctitle
 import zmq
 import zmq.asyncio

@@ -476,6 +477,9 @@ class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        setproctitle.setproctitle(
+            f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
+        )
         # prevent init prefill bootstrapserver again
         disaggregation_mode = server_args.disaggregation_mode
         server_args.disaggregation_mode = "null"

sglang/srt/managers/schedule_policy.py
@@ -380,8 +380,9 @@ class PrefillAdder:
         self.log_input_tokens += extend_input_len

     def add_chunked_req(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
+        truncated = req.extend_input_len > _rem_tokens
+        req.extend_input_len = min(req.extend_input_len, _rem_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
         self._update_prefill_budget(
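
Note (not part of the diff): the new cap in add_chunked_req bounds a chunked prefill by both the per-chunk budget and the remaining total-token budget, instead of the chunk budget alone. A minimal standalone sketch of that capping rule, with made-up numbers (only the min/truncation logic mirrors the change above):

    # Hypothetical budgets, for illustration only.
    rem_chunk_tokens = 8192    # per-chunk prefill budget
    rem_total_tokens = 3000.7  # remaining total budget; int() below mirrors the cast in the diff
    extend_input_len = 5000    # tokens this request still wants to prefill

    # Same rule as the new add_chunked_req: never exceed either budget.
    _rem_tokens = min(rem_chunk_tokens, int(rem_total_tokens))
    truncated = extend_input_len > _rem_tokens             # True: 5000 > 3000
    extend_input_len = min(extend_input_len, _rem_tokens)  # capped to 3000
    print(truncated, extend_input_len)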

sglang/srt/managers/scheduler.py
@@ -141,7 +141,7 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -500,6 +500,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ class Scheduler(
             ]
         )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-        self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

         # Create a new request
         if (
@@ -1568,11 +1556,7 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1897,86 +1881,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            ) # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0: # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,

sglang/srt/managers/scheduler_metrics_mixin.py
@@ -1,15 +1,24 @@
+from __future__ import annotations
+
 import logging
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import torch

 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)

 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ class KvMetrics:


 class SchedulerMetricsMixin:
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+    def init_metrics(
+        self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
+    ):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
                 labels["dp_rank"] = dp_rank
             self.metrics_collector = SchedulerMetricsCollector(labels=labels)

-    def init_kv_events(self, kv_events_config: Optional[str]):
+    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
+        self.balance_meta = dp_balance_meta
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+        self.recv_dp_balance_id_this_term = []
+
+    def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )

     def log_prefill_stats(
-        self,
+        self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
         self._publish_kv_events()

     def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+        self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch

@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
             self._emit_kv_metrics()
         self._publish_kv_events()

-    def _emit_kv_metrics(self):
+    def _emit_kv_metrics(self: Scheduler):
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
         if not self.send_metrics_from_scheduler.closed:
             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

-    def _publish_kv_events(self):
+    def _publish_kv_events(self: Scheduler):
         if self.enable_kv_cache_events:
             events = self.tree_cache.take_events()
             if events:
                 batch = KVEventBatch(ts=time.time(), events=events)
                 self.kv_event_publisher.publish(batch)
+
+    def maybe_update_dp_balance_data(
+        self: Scheduler, recv_req: TokenizedGenerateReqInput
+    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
+    def maybe_handle_dp_balance_data(self: Scheduler):
+        if (
+            self.server_args.load_balance_method == "minimum_tokens"
+            and self.forward_ct % 40 == 0
+        ):
+            holding_tokens = self.get_load()
+
+            new_recv_dp_balance_id_list, holding_token_list = (
+                self.gather_dp_balance_info(holding_tokens)
+            )
+
+            self.recv_dp_balance_id_this_term.clear()
+            if self.tp_rank == 0: # only first worker write info
+                self.write_shared_dp_balance_info(
+                    new_recv_dp_balance_id_list, holding_token_list
+                )
+
+    def gather_dp_balance_info(
+        self: Scheduler, holding_tokens_list
+    ) -> Union[None, List[List[int]]]:
+        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+        recv_list = self.recv_dp_balance_id_this_term
+        assert len(recv_list) <= 511, (
+            "The number of requests received this round is too large. "
+            "Please increase gather_tensor_size and onfly_info_size."
+        )
+        # The maximum size of the tensor used for gathering data from all workers.
+        gather_tensor_size = 512
+
+        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+        recv_tensor[0] = holding_tokens_list
+        recv_tensor[1] = len(recv_list) # The first element is the length of the list.
+        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
+
+        if self.tp_rank == 0:
+            gathered_list = [
+                torch.zeros(gather_tensor_size, dtype=torch.int32)
+                for _ in range(self.balance_meta.num_workers)
+            ]
+        else:
+            gathered_list = None
+
+        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
+
+        gathered_id_list_per_worker = None
+        if self.tp_rank == 0:
+            gathered_id_list_per_worker = []
+            holding_tokens_list = []
+            for tensor in gathered_list:
+                holding_tokens_list.append(tensor[0].item())
+                list_length = tensor[1].item()
+                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
+
+        return gathered_id_list_per_worker, holding_tokens_list
+
+    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
+        meta = self.balance_meta
+
+        with meta.mutex:
+            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
+            # 1.Check if the rid received by each worker this round is present in onfly.
+            # If it is, remove the corresponding onfly item.
+            worker_id = 0
+            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                for new_recv_rid in new_recv_rids:
+                    assert (
+                        new_recv_rid in on_fly_reqs
+                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                    del on_fly_reqs[new_recv_rid]
+                worker_id += 1
+            # 2. Atomically write local_tokens and onfly into shm under the mutex
+            meta.set_shared_onfly_info(onfly_list)
+            meta.set_shared_local_tokens(local_tokens)
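
Note (not part of the diff): gather_dp_balance_info packs each worker's report into a fixed-size int32 tensor (holding-token count, id count, then the ids, zero-padded to 512 slots) so that every rank can join a single torch.distributed.gather with identically shaped tensors. A standalone sketch of just the pack/unpack round trip, with no process group involved (the 512-slot layout mirrors the code above; the function names and values are made up):

    import torch

    GATHER_TENSOR_SIZE = 512  # slot 0: holding_tokens, slot 1: id count, remaining slots: ids (zero padding)

    def pack(holding_tokens: int, dp_balance_ids: list) -> torch.Tensor:
        assert len(dp_balance_ids) <= GATHER_TENSOR_SIZE - 2
        t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
        t[0] = holding_tokens
        t[1] = len(dp_balance_ids)
        t[2 : 2 + len(dp_balance_ids)] = torch.tensor(dp_balance_ids, dtype=torch.int32)
        return t

    def unpack(t: torch.Tensor):
        n = int(t[1].item())
        return int(t[0].item()), t[2 : 2 + n].tolist()

    packed = pack(1234, [7, 8, 9])
    print(unpack(packed))  # (1234, [7, 8, 9])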

sglang/srt/managers/template_manager.py
@@ -24,20 +24,20 @@ import os
 import re
 from typing import Optional

-from sglang.srt.code_completion_parser import (
+from sglang.srt.parser.code_completion_parser import (
     CompletionTemplate,
     FimPosition,
     completion_template_exists,
     register_completion_template,
 )
-from sglang.srt.conversation import (
+from sglang.srt.parser.conversation import (
     Conversation,
     SeparatorStyle,
     chat_template_exists,
     get_conv_template_by_model_path,
     register_conv_template,
 )
-from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+from sglang.srt.parser.jinja_template_utils import detect_jinja_template_content_format

 logger = logging.getLogger(__name__)


sglang/srt/managers/tokenizer_manager.py
@@ -329,6 +329,7 @@ class TokenizerManager:
         # Metrics
         if self.enable_metrics:
             self.metrics_collector = TokenizerMetricsCollector(
+                server_args=server_args,
                 labels={
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,

sglang/srt/mem_cache/allocator.py
@@ -283,7 +283,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.swa_attn_allocator.clear()
         self.full_attn_allocator.clear()
         self.full_to_swa_index_mapping.fill_(0)
-        self.is_in_free_group = False
+        self.is_not_in_free_group = True
         self.free_group = []



sglang/srt/mem_cache/hicache_storage.py
@@ -27,6 +27,7 @@ class HiCacheStorageConfig:
     tp_rank: int
     tp_size: int
     is_mla_model: bool
+    is_page_first_layout: bool
     model_name: Optional[str]
     extra_config: Optional[dict] = None

@@ -135,18 +136,24 @@ class HiCacheFile(HiCacheStorage):
     ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)

-        tp_rank, tp_size, is_mla = (
+        tp_rank, tp_size, model_name, is_mla_model = (
             storage_config.tp_rank,
             storage_config.tp_size,
+            storage_config.model_name,
             storage_config.is_mla_model,
         )
-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
+        model_name = "-".join(model_name.split("/")) if model_name else ""
+        if is_mla_model:
+            self.config_suffix = f"_{model_name}"
+        else:
+            self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
+
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

     def _get_suffixed_key(self, key: str) -> str:
-        return key + self.tp_suffix
+        return key + self.config_suffix

     def get(
         self,
@@ -157,13 +164,11 @@ class HiCacheFile(HiCacheStorage):
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            # Load directly into target_location's memory buffer
-            with open(tensor_path, "rb") as f:
-                target_location.set_(
-                    torch.frombuffer(f.read(), dtype=target_location.dtype)
-                    .reshape(target_location.shape)
-                    .untyped_storage()
-                )
+            expected = target_location.numel() * target_location.element_size()
+            with open(tensor_path, "rb", buffering=0) as f:
+                buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
+                if f.readinto(buf) != expected:
+                    raise IOError(f"Short read for {key}")
             return target_location
         except FileNotFoundError:
             logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")

sglang/srt/mem_cache/hiradix_cache.py
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
             req_id
-        )
+        ]

         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -512,6 +512,7 @@
                 host_indices[min_completed_tokens:completed_tokens]
             )
         last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

         return True
@@ -775,9 +776,7 @@ class HiRadixCache(RadixCache):
         if rid not in self.ongoing_prefetch:
             return

-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
-            rid
-        )
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
         if operation.host_indices is None:
             return

@@ -785,5 +784,6 @@ class HiRadixCache(RadixCache):
         if self.tp_world_size > 1:
             torch.distributed.barrier(group=self.tp_group)
         last_host_node.release_host()
+        del self.ongoing_prefetch[rid]
         self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

sglang/srt/mem_cache/memory_pool_host.py
@@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         key_list = []
         buf_list = []

-        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+        for i in range(len(keys)):
+            key = keys[i]
             key_list.append(f"{key}-k")
-            buf_list.append(self.k_buffer[i : i + self.page_size])
             key_list.append(f"{key}-v")
-            buf_list.append(self.v_buffer[i : i + self.page_size])
+            if indices is not None:
+                index = indices[i * self.page_size]
+                buf_list.append(self.k_buffer[index : index + self.page_size])
+                buf_list.append(self.v_buffer[index : index + self.page_size])

-        return key_list, buf_list
+        return key_list, buf_list, 2


 class MLATokenToKVPoolHost(HostKVCache):
@@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         buf_list = []

-        for i in range(0, len(indices), self.page_size):
-            buf_list.append(self.kv_buffer[i : i + self.page_size])
+        if indices is not None:
+            for i in range(len(keys)):
+                index = indices[i * self.page_size]
+                buf_list.append(self.kv_buffer[index : index + self.page_size])

-        return keys, buf_list
+        return keys, buf_list, 1

sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -128,6 +128,7 @@ class HiCacheHF3FS(HiCacheStorage):
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
         is_mla_model: bool = False,
+        is_page_first_layout: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -138,6 +139,7 @@ class HiCacheHF3FS(HiCacheStorage):
         self.dtype = dtype
         self.metadata_client = metadata_client
         self.is_mla_model = is_mla_model
+        self.is_page_first_layout = is_page_first_layout
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
         self.skip_backup = False
@@ -193,9 +195,13 @@ class HiCacheHF3FS(HiCacheStorage):
         )

         if storage_config is not None:
-            rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model
+            rank, is_mla_model, is_page_first_layout = (
+                storage_config.tp_rank,
+                storage_config.is_mla_model,
+                storage_config.is_page_first_layout,
+            )
         else:
-            rank, is_mla_model = 0, False
+            rank, is_mla_model, is_page_first_layout = 0, False, False

         mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"

@@ -213,6 +219,7 @@ class HiCacheHF3FS(HiCacheStorage):
                 entries=8,
                 dtype=dtype,
                 metadata_client=Hf3fsLocalMetadataClient(),
+                is_page_first_layout=is_page_first_layout,
             )

         try:
@@ -261,6 +268,7 @@ class HiCacheHF3FS(HiCacheStorage):
             dtype=dtype,
             metadata_client=metadata_client,
             is_mla_model=is_mla_model,
+            is_page_first_layout=is_page_first_layout,
         )

     def get(