sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py

@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional

 import torch

-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):

         if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )

+        self.prefetch_stop_policy = hicache_storage_prefetch_policy
+        # todo: customizable storage prefetch timeout
+        self.prefetch_timeout = 3  # seconds
+        logger.info(
+            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+        )
+
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
                 del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
             else:
                 # the revoked operation already got terminated
                 pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
             host_node.release_host()
             del self.ongoing_backup[ack_id]

-    def check_prefetch_progress(self, req_id: str):
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
@@ -430,13 +471,16 @@
             req_id
         ]

+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True

     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
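
A note on the new hicache_storage_prefetch_policy plumbing above: the three policies decide when an in-flight storage prefetch may be terminated. The sketch below is a simplified, standalone illustration of that decision table (not sglang code; the Operation dataclass and the timeout argument are stand-ins for PrefetchOperation and the hard-coded 3-second prefetch_timeout):

import time
from dataclasses import dataclass, field

@dataclass
class Operation:
    # Simplified stand-in for sglang's PrefetchOperation.
    hash_value: list                 # one hash per prefetched page
    completed_tokens: int = 0        # tokens already fetched from storage
    start_time: float = field(default_factory=time.monotonic)

def can_terminate(op: Operation, policy: str, page_size: int, timeout_s: float = 3.0) -> bool:
    """Single-rank view of the prefetch-stop decision table."""
    if policy == "best_effort":
        return True                              # stop as soon as the scheduler asks
    completed = op.completed_tokens == len(op.hash_value) * page_size
    if policy == "wait_complete":
        return completed                         # only stop once every page has landed
    if policy == "timeout":
        return completed or (time.monotonic() - op.start_time > timeout_s)
    return True                                  # unknown policy: fall back to best effort

# A half-finished prefetch of 4 pages with page_size=16 (32 of 64 tokens done).
op = Operation(hash_value=[0, 1, 2, 3], completed_tokens=32)
print(can_terminate(op, "best_effort", 16))      # True
print(can_terminate(op, "wait_complete", 16))    # False

With tensor parallelism, the real code additionally all-reduces the boolean with ReduceOp.MIN so that a rank only terminates once every rank agrees.
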
sglang/srt/mem_cache/memory_pool_host.py

@@ -618,7 +618,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         elif self.layout == "page_first":
             transfer_kv_all_layer_mla_lf_pf(
                 src_layers=device_pool.data_ptrs,
-                dst_k=self.kv_buffer,
+                dst=self.kv_buffer,
                 src_indices=device_indices,
                 dst_indices=host_indices,
                 item_size=self.token_stride_size,
sglang/srt/mem_cache/multimodal_cache.py

@@ -1,24 +1,46 @@
+import logging
+from collections import OrderedDict
 from typing import Dict

 import torch

+# Set up logging for cache behavior
+logger = logging.getLogger(__name__)
+

 class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results"""
+    """MultiModalCache is used to store vlm encoder results with LRU eviction"""

     def __init__(
         self,
         max_size: int,
     ):
         self.max_size = max_size
-        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0

+    def _allocate(self, embedding_size: int) -> bool:
+        """Allocate space by evicting least recently used entries"""
+        evictions = 0
+        while self.current_size + embedding_size > self.max_size and self.mm_cache:
+            _, old_embedding = self.mm_cache.popitem(last=False)
+            evicted_size = self._get_tensor_size(old_embedding)
+            self.current_size -= evicted_size
+            evictions += evicted_size
+
+        if evictions > 0:
+            logger.debug(
+                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+            )
+
+        if self.current_size + embedding_size > self.max_size:
+            return False
+        return True
+
     def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-        if mm_hash in self.mm_cache:
-            return True
         data_size = self._get_tensor_size(embedding)
-        if self.current_size + data_size > self.max_size:
+        # Lazy free cache if not enough space
+        if not self._allocate(data_size):
             return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
@@ -28,14 +50,12 @@ class MultiModalCache:
         return mm_hash in self.mm_cache

     def get(self, mm_hash: int) -> torch.Tensor:
-        return self.mm_cache.get(mm_hash)
-
-    def free(self, mm_hash: int) -> bool:
-        if mm_hash not in self.mm_cache:
-            return False
-        old_embedding = self.mm_cache.pop(mm_hash)
-        self.current_size -= self._get_tensor_size(old_embedding)
-        return True
+        """Get embedding and update LRU order"""
+        if mm_hash in self.mm_cache:
+            # Move to end (most recently used)
+            self.mm_cache.move_to_end(mm_hash)
+            return self.mm_cache[mm_hash]
+        return None

     def clear(self):
         self.mm_cache.clear()
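
The rewritten MultiModalCache above swaps the plain dict plus explicit free() for an OrderedDict-driven LRU bounded by a byte budget. A minimal standalone sketch of the same pattern, using byte strings in place of torch tensors (illustrative only):

from collections import OrderedDict

class LruByteCache:
    """Byte-budgeted LRU cache: inserting evicts least recently used entries."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.cache: OrderedDict[int, bytes] = OrderedDict()
        self.current_size = 0

    def put(self, key: int, value: bytes) -> bool:
        size = len(value)
        # Evict from the LRU end until the new entry fits.
        while self.current_size + size > self.max_size and self.cache:
            _, old = self.cache.popitem(last=False)
            self.current_size -= len(old)
        if self.current_size + size > self.max_size:
            return False                  # a single item larger than the whole budget
        self.cache[key] = value
        self.current_size += size
        return True

    def get(self, key: int):
        if key in self.cache:
            self.cache.move_to_end(key)   # mark as most recently used
            return self.cache[key]
        return None

cache = LruByteCache(max_size=8)
cache.put(1, b"aaaa")
cache.put(2, b"bbbb")
cache.get(1)                              # touch key 1, so key 2 becomes the LRU entry
cache.put(3, b"cccc")                     # evicts key 2
print(list(cache.cache))                  # [1, 3]
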
sglang/srt/mem_cache/radix_cache_cpp.py (new file)

@@ -0,0 +1,229 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, List, Set
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
+    IOHandle,
+    RadixTreeCpp,
+    TreeNodeCpp,
+)
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
+
+logger = logging.getLogger(__name__)
+
+
+class RadixCacheCpp(BasePrefixCache):
+    def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Merge a list of tensors into a single tensor.
+        Args:
+            l (List[torch.Tensor]): List of tensors to merge.
+        Returns:
+            torch.Tensor: Merged tensor.
+        """
+        if len(l) == 0:
+            return torch.empty(0, dtype=torch.int64, device=self.device)
+        elif len(l) == 1:
+            return l[0]
+        else:
+            return torch.cat(l)
+
+    def __init__(
+        self,
+        disable: bool,
+        use_hicache: bool,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
+        page_size: int,
+        hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
+        enable_kv_cache_events: bool = False,
+        hicache_oracle: bool = False,
+        enable_write_cancel: bool = False,
+    ):
+        self.disable = disable
+        self.enable_write_cancel = enable_write_cancel
+
+        assert (
+            enable_kv_cache_events is False
+        ), "HiRadixCache does not support kv cache events yet"
+        self.kv_cache = token_to_kv_pool.get_kvcache()
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through: Set[IOHandle] = set()
+        # record the node segments with ongoing load back
+        self.ongoing_load_back: Set[IOHandle] = set()
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 2
+        )
+        self.device = token_to_kv_pool.device
+        self.token_to_kv_pool = token_to_kv_pool
+        self.req_to_token_pool = req_to_token_pool
+        self.page_size = page_size
+
+        self.tp_group = tp_cache_group
+
+        if not use_hicache:
+            self.tree = RadixTreeCpp(
+                disabled=self.disable,
+                page_size=page_size,
+                host_size=None,  # no host cache, this should be removed in the future
+                write_through_threshold=self.write_through_threshold,
+            )
+            self.cache_controller = None
+            return  # early return if hicache is not used
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def reset(self):
+        if self.cache_controller is not None:
+            # need to clear the acks before resetting the cache controller
+            raise NotImplementedError("Host cache is not supported yet")
+        self.tree.reset()
+
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+        device_indices_vec, host_indices_length, node_gpu, node_cpu = (
+            self.tree.match_prefix(key)
+        )
+        return MatchResult(
+            device_indices=self._merge_tensor(device_indices_vec),
+            last_device_node=node_gpu,
+            last_host_node=node_cpu,
+            host_hit_length=host_indices_length,
+        )
+
+    def _insert(self, key: List[int], value: torch.Tensor) -> int:
+        """
+        Insert a key-value pair into the radix tree.
+        Args:
+            key (List[int]): The key to insert, represented as a list of integers.
+            value (torch.Tensor): The value to associate with the key.
+        Returns:
+            int: Number of device indices that were already present in the tree before the insertion.
+        """
+        ongoing_write, length = self.tree.writing_through(key, value)
+        if self.cache_controller is None:
+            assert len(ongoing_write) == 0, "Implementation error"
+            return length
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def dec_lock_ref(self, node: TreeNodeCpp):
+        """
+        Decrement the reference count of a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to decrement the reference count for.
+        """
+        self.tree.lock_ref(node, False)  # do not increment
+
+    def inc_lock_ref(self, node: TreeNodeCpp):
+        """
+        Increment the reference count of from a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to increment the reference count for.
+        """
+        self.tree.lock_ref(node, True)
+
+    def evict(self, num_tokens: int):
+        evicted_device_indices = self.tree.evict(num_tokens)
+        for indice in evicted_device_indices:
+            self.token_to_kv_pool.free(indice)
+
+    def evictable_size(self):
+        return self.tree.evictable_size()
+
+    def protected_size(self):
+        return self.tree.protected_size()
+
+    def total_size(self):
+        return self.tree.total_size()
+
+    def cache_finished_req(self, req: Req):
+        """Cache request when it finishes."""
+        assert req.req_pool_idx is not None
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        overall_len = len(token_ids)  # prefill + decode
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+
+        # need to free the unaligned part, since it cannot be inserted into the radix tree
+        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
+            (unaligned_len := overall_len % self.page_size) > 0
+        ):
+            # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
+            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
+
+        # Remove req slot release the cache lock
+        self.dec_lock_ref(req.last_node)
+        self.req_to_token_pool.free(req.req_pool_idx)
+
+    def cache_unfinished_req(self, req: Req):
+        """Cache request when it is unfinished."""
+        assert req.req_pool_idx is not None
+        token_ids = req.fill_ids
+        prefill_len = len(token_ids)  # prefill only (maybe chunked)
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
+        # The prefix indices need to updated to reuse the kv indices in the pool
+        new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
+        new_indices = self._merge_tensor(new_indices_vec)
+        assert new_prefix_len <= len(new_indices)
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices and reuse the indices in the pool
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+            reused_indices = new_indices[old_prefix_len:new_prefix_len]
+            self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, old_prefix_len:new_prefix_len
+            ] = reused_indices
+
+        if req.last_node != new_last_node:
+            self.dec_lock_ref(req.last_node)
+            self.inc_lock_ref(new_last_node)
+
+        # NOTE: there might be unaligned tail, so we may need to append it
+        assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
+        if self.page_size != 1 and len(new_indices) < prefill_len:
+            req.prefix_indices = torch.cat(
+                [new_indices, kv_indices[len(new_indices) :]]
+            )
+        else:
+            req.prefix_indices = new_indices
+        req.last_node = new_last_node
+
+    def pretty_print(self):
+        return self.tree.debug_print()
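
One detail of the new RadixCacheCpp worth noting is its page-alignment arithmetic: the previously cached prefix is rounded down to a page boundary, and the unaligned tail of a finished request is freed rather than inserted into the tree. A small self-contained illustration of that arithmetic (plain Python with a hypothetical helper name, not part of the actual class):

def split_for_insert(prefix_len: int, overall_len: int, page_size: int):
    """Round the cached prefix down to a page boundary and isolate the unaligned tail."""
    old_prefix_len = prefix_len // page_size * page_size   # page-aligned prefix to reuse
    unaligned_len = overall_len % page_size                # tail that cannot enter the tree
    insertable_len = overall_len - unaligned_len
    return old_prefix_len, insertable_len, unaligned_len

# page_size=16, a request with a 37-token cached prefix and 100 tokens overall:
print(split_for_insert(37, 100, 16))   # (32, 96, 4) -> the 4 tail tokens are freed, not cached
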
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py

@@ -96,6 +96,8 @@ class Hf3fsClient:
         )
         self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
         self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()

         self.rlock = threading.RLock()
         self.wlock = threading.RLock()
@@ -176,8 +178,6 @@
         del self.iov_w
         self.shm_r.close()
         self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()

     def flush(self) -> None:
         os.fsync(self.file)
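
The Hf3fsClient change above moves shm_r.unlink() / shm_w.unlink() from close() to immediately after the iovecs are created: the shared-memory names disappear from /dev/shm right away while the existing mappings stay usable, so the segments are reclaimed even if the process exits without a clean close(). A standalone sketch of this unlink-early pattern with Python's multiprocessing.shared_memory on POSIX systems (illustrative, not the sglang code path):

from multiprocessing import shared_memory

# Create a segment, then unlink its name immediately: the mapping stays usable,
# but the backing object is released as soon as the last handle is closed.
shm = shared_memory.SharedMemory(create=True, size=1024)
shm.unlink()                 # name removed now; nothing leaks if the process crashes later

shm.buf[:5] = b"hello"       # the mapping is still fully readable and writable
print(bytes(shm.buf[:5]))

shm.close()                  # last handle closed, memory actually released

The trade-off is that no other process can attach to the segment by name after the unlink, which is acceptable here because both iovecs are created before the names are removed.
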
sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp (new file)

@@ -0,0 +1,35 @@
+#include <torch/extension.h>
+
+#include <cstring>
+#include <vector>
+
+void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
+  py::gil_scoped_release release;
+  char *src_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < dst.size(); ++i) {
+    auto &t = dst[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *dst_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr, src_ptr + current, t_bytes);
+    current += t_bytes;
+  }
+}
+
+void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
+  py::gil_scoped_release release;
+  char *dst_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < src.size(); ++i) {
+    auto &t = src[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *src_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr + current, src_ptr, t_bytes);
+    current += t_bytes;
+  }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("read_shm", &read_shm, "Read tensors from shared memory");
+  m.def("write_shm", &write_shm, "Write tensors to shared memory");
+}
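
The new hf3fs_utils.cpp above is a small PyTorch/pybind11 extension that memcpy's a packed shared-memory tensor into (or out of) a list of tensors while releasing the GIL. As a rough sketch of how such a source file could be JIT-compiled and exercised with torch.utils.cpp_extension.load (the actual build wiring inside sglang may differ, and the source path here is assumed):

import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension from source (requires a C++ toolchain).
hf3fs_utils = load(name="hf3fs_utils", sources=["hf3fs_utils.cpp"])

shm = torch.empty(6, dtype=torch.uint8)        # stands in for the shared-memory buffer
parts = [
    torch.tensor([1, 2], dtype=torch.uint8),
    torch.tensor([3, 4, 5, 6], dtype=torch.uint8),
]

hf3fs_utils.write_shm(parts, shm)              # pack the tensors back-to-back into shm
out = [torch.empty_like(p) for p in parts]
hf3fs_utils.read_shm(shm, out)                 # unpack shm into the preallocated tensors
print(out)                                     # [tensor([1, 2], ...), tensor([3, 4, 5, 6], ...)]
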
sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -372,6 +375,11 @@ class CudaGraphRunner:
             dtype=torch.bool,
             device="cuda",
         )
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_num_token, self.model_runner.model_config.vocab_size),
+            dtype=torch.float,
+            device="cuda",
+        )

         # Capture
         try:
@@ -517,6 +525,7 @@
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens

         # pipeline parallelism
@@ -567,11 +576,11 @@
         )

         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
-            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
-            lora_paths = [None] * bs
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-            lora_paths = None
+            lora_ids = None

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -579,6 +588,8 @@
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
@@ -597,11 +608,11 @@
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-            lora_paths=lora_paths,
+            lora_ids=lora_ids,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

-        if lora_paths is not None:
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

         # Attention backend
@@ -643,11 +654,15 @@

         run_once()

-        global global_graph_memory_pool
-        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
+        if get_global_graph_memory_pool() is None:
+            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+        # Set graph pool id globally to be able to use symmetric memory
+        set_graph_pool_id(get_global_graph_memory_pool())
+        with torch.cuda.graph(
+            graph, pool=get_global_graph_memory_pool(), stream=stream
+        ):
             out = run_once()

-        global_graph_memory_pool = graph.pool()
         return graph, out

     def recapture_if_needed(self, forward_batch: ForwardBatch):
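
The CudaGraphRunner change above replaces the module-level global_graph_memory_pool with getter/setter accessors and registers the pool with the new pynccl allocator via set_graph_pool_id, so all captured graphs share one memory pool. A minimal standalone sketch of capturing two graphs against a shared pool using stock PyTorch APIs (illustrative; it omits the sglang-specific set_graph_pool_id hook and requires a CUDA device):

import torch

if torch.cuda.is_available():
    pool = torch.cuda.graph_pool_handle()      # one allocation pool shared by both graphs
    static_x = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, as recommended for CUDA graphs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_x * 2
        static_x + 3
    torch.cuda.current_stream().wait_stream(s)

    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):      # capture against the shared pool
        y1 = static_x * 2
    with torch.cuda.graph(g2, pool=pool):
        y2 = static_x + 3

    static_x.fill_(5.0)
    g1.replay()                                # replay in capture order when sharing a pool
    g2.replay()
    torch.cuda.synchronize()
    print(y1[:3], y2[:3])                      # outputs reflect the current static_x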