sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/triton_ops/bitmask_ops.py ADDED
@@ -0,0 +1,141 @@
+# Adapt from
+# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+
+from typing import List, Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.utils import get_device_core_count
+
+
+@triton.jit
+def apply_token_bitmask_inplace_kernel(
+    logits_ptr,
+    bitmask_ptr,
+    indices_ptr,
+    num_rows,
+    vocab_size,
+    logits_strides,
+    bitmask_strides,
+    NUM_SMS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
+    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
+    the masked logits will be set to -inf.
+
+    Parameters
+    ----------
+    logits_ptr : tl.tensor
+        Pointer to the logits tensor to apply the bitmask to.
+
+    bitmask_ptr : tl.tensor
+        Pointer to the bitmask tensor to apply.
+
+    indices_ptr : Optional[tl.tensor]
+        Optional pointer to indices tensor specifying which rows to apply the mask to.
+
+    num_rows : int
+        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
+
+    vocab_size : int
+        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
+        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
+
+    logits_strides : int
+        Stride between rows in the logits tensor.
+
+    bitmask_strides : int
+        Stride between rows in the bitmask tensor.
+
+    NUM_SMS : int
+        Number of streaming multiprocessors to use.
+
+    BLOCK_SIZE : int
+        Size of processing blocks.
+    """
+
+    pid = tl.program_id(0)
+    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
+    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
+        row_id = work_id // num_blocks
+        block_offset = (work_id % num_blocks) * BLOCK_SIZE
+        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
+        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
+        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
+        vocab_mask = offsets < vocab_size
+        packed_bitmask_mask = bitmask_offsets < bitmask_strides
+        packed_bitmask = tl.load(
+            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
+            packed_bitmask_mask,
+        )
+        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+        bitmask = bitmask.reshape(BLOCK_SIZE)
+
+        tl.store(
+            logits_ptr + batch_id * logits_strides + offsets,
+            -float("inf"),
+            vocab_mask & bitmask,
+        )
+
+
+def apply_token_bitmask_inplace_triton(
+    logits: torch.Tensor,
+    bitmask: torch.Tensor,
+    indices: Optional[Union[List[int], torch.Tensor]] = None,
+):
+    NUM_SMS = get_device_core_count()
+    BLOCK_SIZE = 4096
+    BITS_PER_BLOCK = 32
+
+    # Check input dtype
+    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
+
+    # Check input tensor shapes.
+    logits_shape = logits.shape
+    bitmask_shape = bitmask.shape
+    if logits.ndim == 1:
+        logits_shape = (1, logits_shape[0])
+    if bitmask.ndim == 1:
+        bitmask_shape = (1, bitmask_shape[0])
+
+    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
+    assert required_bitmask_width >= bitmask_shape[1], (
+        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
+        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
+    )
+
+    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)
+
+    num_rows = None
+    if isinstance(indices, list) or isinstance(indices, torch.Tensor):
+        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
+        num_rows = indices.shape[0]
+    else:
+        assert (
+            logits_shape[0] == bitmask_shape[0]
+        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
+        num_rows = logits_shape[0]
+
+    if NUM_SMS > 0:
+        grid = (NUM_SMS,)
+    else:
+        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+        grid = (num_rows * num_blocks,)
+        NUM_SMS = triton.next_power_of_2(grid[0])
+
+    apply_token_bitmask_inplace_kernel[grid](
+        logits,
+        bitmask,
+        indices,
+        num_rows,
+        vocab_size,
+        logits_shape[1],
+        bitmask_shape[1],
+        NUM_SMS,
+        BLOCK_SIZE,
+        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
+        num_stages=3,
+    )
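
For orientation, here is a minimal usage sketch of the new helper (not part of the package; the shapes, values, and availability of a CUDA device are illustrative assumptions). Each int32 in the bitmask packs 32 vocabulary slots; a 1 bit keeps the corresponding token and a 0 bit drives its logit to -inf in place.

import torch
from sglang.srt.constrained.triton_ops.bitmask_ops import apply_token_bitmask_inplace_triton

# Illustrative only: a batch of 2 rows over a 64-token vocabulary -> 2 int32 words per row.
logits = torch.randn(2, 64, device="cuda")
bitmask = torch.full((2, 2), -1, dtype=torch.int32, device="cuda")  # all bits set: every token allowed
bitmask[0, 0] &= ~(1 << 5)  # clear bit 5 of the first word: token id 5 is masked for row 0

apply_token_bitmask_inplace_triton(logits, bitmask)
assert logits[0, 5].item() == float("-inf")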
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -25,13 +25,16 @@ from xgrammar import (
     StructuralTagItem,
     TokenizerInfo,
     allocate_token_bitmask,
-    apply_token_bitmask_inplace,
 )
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
+from sglang.srt.constrained.triton_ops.bitmask_ops import (
+    apply_token_bitmask_inplace_triton,
+)
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -48,12 +51,25 @@ class XGrammarGrammar(BaseGrammarObject):
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
     ) -> None:
+        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
 
+        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
+        # class init site to avoid re-initializing CUDA in forked subprocess.
+        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+
+        self.use_token_bitmask_triton = get_bool_env_var(
+            "SGLANG_TOKEN_BITMASK_TRITON", "false"
+        )
+        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
+            "cuda", None
+        )
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
 
@@ -96,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
     def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
         return vocab_mask.to(device, non_blocking=True)
 
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        apply_token_bitmask_inplace(logits, vocab_mask)
+    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if (
+            not self.use_token_bitmask_triton
+            and logits.device.type == "cuda"
+            and self.apply_vocab_mask_cuda
+        ):
+            return self.apply_vocab_mask_cuda(logits, vocab_mask)
+        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            return self.apply_vocab_mask_cpu(logits, vocab_mask)
+        apply_token_bitmask_inplace_triton(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
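
The net effect is a three-way dispatch in apply_vocab_mask: xgrammar's prebuilt CUDA kernel when the logits live on a CUDA device (unless overridden), xgrammar's CPU kernel for CPU logits, and otherwise the Triton kernel above. The override is the environment variable read at grammar construction time; a small, assumed usage sketch:

import os

# Force the Triton bitmask kernel even when xgrammar's CUDA kernel is available.
# Read via get_bool_env_var("SGLANG_TOKEN_BITMASK_TRITON", "false") in XGrammarGrammar.__init__.
os.environ["SGLANG_TOKEN_BITMASK_TRITON"] = "true"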
sglang/srt/custom_op.py CHANGED
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
             return self.forward_hip
         else:
             return self.forward_native
-
-
-if _is_cuda:
-    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
-    def scaled_fp8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None,
-        num_token_padding: Optional[int] = None,
-        use_per_token_if_dynamic: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Quantize input tensor to FP8 (8-bit floating point) format.
-
-        Args:
-            input (torch.Tensor): Input tensor to be quantized
-            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-                If None, scales will be computed dynamically.
-            num_token_padding (Optional[int]): If specified, pad the first dimension
-                of the output to at least this value.
-            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-                determines the quantization granularity:
-                - True: compute scale per token
-                - False: compute single scale per tensor
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-                - quantized_tensor: The FP8 quantized version of input
-                - scale_tensor: The scaling factors used for quantization
-
-        Raises:
-            AssertionError: If input is not 2D or if static scale's numel != 1
-        """
-        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-        shape = input.shape
-        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-        if num_token_padding:
-            shape = (max(num_token_padding, input.shape[0]), shape[1])
-        output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-        if scale is None:
-            # Dynamic scaling
-            if use_per_token_if_dynamic:
-                scale = torch.empty(
-                    (shape[0], 1), device=input.device, dtype=torch.float32
-                )
-                sgl_per_token_quant_fp8(input, output, scale)
-            else:
-                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-                sgl_per_tensor_quant_fp8(
-                    input, output, scale, is_static=False
-                )  # False for dynamic
-        else:
-            # Static scaling
-            assert (
-                scale.numel() == 1
-            ), f"Expected scalar scale, got numel={scale.numel()}"
-            sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=True
-            )  # True for static
-
-        return output, scale
sglang/srt/disaggregation/base/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .conn import (
+    BaseKVBootstrapServer,
+    BaseKVManager,
+    BaseKVReceiver,
+    BaseKVSender,
+    KVArgs,
+    KVPoll,
+)
sglang/srt/disaggregation/base/conn.py ADDED
@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.server_args import ServerArgs
+
+
+class KVArgs:
+    engine_rank: int
+    kv_data_ptrs: list[int]
+    kv_data_lens: list[int]
+    kv_item_lens: list[int]
+    aux_data_ptrs: list[int]
+    aux_data_lens: list[int]
+    aux_item_lens: list[int]
+    ib_device: str
+    gpu_id: int
+
+
+class KVPoll:
+    Failed = 0
+    Bootstrapping = 1
+    WaitingForInput = 2
+    Transferring = 3
+    Success = 4
+
+
+class BaseKVManager(ABC):
+    """Base class for managing transfers states"""
+
+    @abstractmethod
+    def __init__(
+        self,
+        args: KVArgs,
+        disaggregation_mode: DisaggregationMode,
+        server_args: ServerArgs,
+    ): ...
+
+
+class BaseKVSender(ABC):
+
+    @abstractmethod
+    def __init__(
+        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
+    ): ...
+
+    @abstractmethod
+    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
+        """
+        Notify the decoder server about the kv indices length and aux index
+        """
+        ...
+
+    @abstractmethod
+    def send(self, kv_indices: npt.NDArray[np.int64]):
+        """
+        Send the kv cache at the given kv indices to the decoder server
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVReceiver(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ): ...
+
+    @abstractmethod
+    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
+        """
+        Notify the prefill server about the kv indices and aux index
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVBootstrapServer(ABC):
+    @abstractmethod
+    def __init__(self, port: int): ...
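
These abstract classes pin down the contract a KV-cache transfer backend must implement; the new Mooncake backend (files 24-26 in the list above) is one such implementation. A skeletal, purely hypothetical sender is sketched below to show the expected lifecycle (bootstrap, init, send, poll); none of this code exists in the package.

from typing import Optional

import numpy as np
import numpy.typing as npt

from sglang.srt.disaggregation.base import BaseKVManager, BaseKVSender, KVPoll


class DummyKVSender(BaseKVSender):
    """Hypothetical sender: illustrates the interface, performs no real transfer."""

    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
        self.mgr = mgr
        self.bootstrap_addr = bootstrap_addr
        self.bootstrap_room = bootstrap_room
        self.state = KVPoll.Bootstrapping

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        # Tell the decode side how many KV indices (and which aux slot) to expect.
        self.state = KVPoll.WaitingForInput

    def send(self, kv_indices: npt.NDArray[np.int64]):
        # A real backend would push the KV pages referenced by kv_indices here.
        self.state = KVPoll.Success

    def poll(self) -> KVPoll:
        return self.state

    def failure_exception(self):
        raise RuntimeError("KV transfer failed")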
sglang/srt/disaggregation/decode.py CHANGED
@@ -24,12 +24,18 @@ import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
+import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
     ReqToMetadataIdxAllocator,
+    TransferBackend,
+    get_kv_class,
+    kv_to_page_indices,
     poll_and_all_reduce,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -49,7 +55,7 @@ if TYPE_CHECKING:
 @dataclass
 class DecodeRequest:
     req: Req
-    kv_receiver: KVReceiver
+    kv_receiver: BaseKVReceiver
     waiting_for_input: bool = False
     metadata_buffer_index: int = -1
 
@@ -73,6 +79,7 @@ class DecodePreallocQueue:
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
+        transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -92,9 +99,10 @@ class DecodePreallocQueue:
 
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
+        self.transfer_backend = transfer_backend
        self.kv_manager = self._init_kv_manager()
 
-    def _init_kv_manager(self) -> KVManager:
+    def _init_kv_manager(self) -> BaseKVManager:
        kv_args = KVArgs()
         kv_args.engine_rank = self.tp_rank
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
@@ -114,14 +122,19 @@ class DecodePreallocQueue:
         kv_args.aux_item_lens = [
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
-        kv_args.ib_device = "mock-ib-device"
-        kv_manager = KVManager(kv_args)
+        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
+        kv_args.gpu_id = self.scheduler.gpu_id
+        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
+        kv_manager = kv_manager_class(
+            kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
+        )
         return kv_manager
 
     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""
 
-        kv_receiver = KVReceiver(
+        kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
+        kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
@@ -186,13 +199,17 @@ class DecodePreallocQueue:
                 ]
                 .cpu()
                 .numpy()
+                .astype(np.int64)
             )
 
             decode_req.metadata_buffer_index = (
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert decode_req.metadata_buffer_index is not None
-            decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
+            page_indices = kv_to_page_indices(
+                kv_indices, self.token_to_kv_pool_allocator.page_size
+            )
+            decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
 
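kv_to_page_indices (imported from sglang.srt.disaggregation.utils above) converts the token-level KV slots into page-level indices before they are handed to the receiver, so that backends can transfer whole pages when page_size > 1. Its implementation is not shown in this diff; a plausible stand-in, for illustration only, is:

import numpy as np

def kv_to_page_indices_sketch(kv_indices: np.ndarray, page_size: int) -> np.ndarray:
    # Illustrative sketch, not the packaged implementation.
    if page_size == 1:
        return kv_indices  # every KV slot is its own page
    # Map each slot to its page id and deduplicate (np.unique also sorts,
    # which is harmless for contiguously allocated pages).
    return np.unique(kv_indices // page_size)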
@@ -232,10 +249,30 @@ class DecodePreallocQueue:
         assert req_pool_indices is not None
 
         req.req_pool_idx = req_pool_indices[0]
-        kv_loc = self.token_to_kv_pool_allocator.alloc(
-            len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
-        )
-
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            kv_loc = self.token_to_kv_pool_allocator.alloc(
+                len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+            )
+        else:
+            num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+            kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
+                prefix_lens=torch.tensor(
+                    [0],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                seq_lens=torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                last_loc=torch.tensor(
+                    [-1],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                extend_num_tokens=num_tokens,
+            )
         assert kv_loc is not None
 
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
@@ -406,6 +443,38 @@ class ScheduleBatchDisaggregationDecodeMixin:
 
 class SchedulerDisaggregationDecodeMixin:
 
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]: