sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff compares two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/triton_ops/bitmask_ops.py
ADDED
@@ -0,0 +1,141 @@
+# Adapt from
+# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+
+from typing import List, Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.utils import get_device_core_count
+
+
+@triton.jit
+def apply_token_bitmask_inplace_kernel(
+    logits_ptr,
+    bitmask_ptr,
+    indices_ptr,
+    num_rows,
+    vocab_size,
+    logits_strides,
+    bitmask_strides,
+    NUM_SMS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
+    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
+    the masked logits will be set to -inf.
+
+    Parameters
+    ----------
+    logits_ptr : tl.tensor
+        Pointer to the logits tensor to apply the bitmask to.
+
+    bitmask_ptr : tl.tensor
+        Pointer to the bitmask tensor to apply.
+
+    indices_ptr : Optional[tl.tensor]
+        Optional pointer to indices tensor specifying which rows to apply the mask to.
+
+    num_rows : int
+        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
+
+    vocab_size : int
+        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
+        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
+
+    logits_strides : int
+        Stride between rows in the logits tensor.
+
+    bitmask_strides : int
+        Stride between rows in the bitmask tensor.
+
+    NUM_SMS : int
+        Number of streaming multiprocessors to use.
+
+    BLOCK_SIZE : int
+        Size of processing blocks.
+    """
+
+    pid = tl.program_id(0)
+    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
+    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
+        row_id = work_id // num_blocks
+        block_offset = (work_id % num_blocks) * BLOCK_SIZE
+        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
+        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
+        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
+        vocab_mask = offsets < vocab_size
+        packed_bitmask_mask = bitmask_offsets < bitmask_strides
+        packed_bitmask = tl.load(
+            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
+            packed_bitmask_mask,
+        )
+        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+        bitmask = bitmask.reshape(BLOCK_SIZE)
+
+        tl.store(
+            logits_ptr + batch_id * logits_strides + offsets,
+            -float("inf"),
+            vocab_mask & bitmask,
+        )
+
+
+def apply_token_bitmask_inplace_triton(
+    logits: torch.Tensor,
+    bitmask: torch.Tensor,
+    indices: Optional[Union[List[int], torch.Tensor]] = None,
+):
+    NUM_SMS = get_device_core_count()
+    BLOCK_SIZE = 4096
+    BITS_PER_BLOCK = 32
+
+    # Check input dtype
+    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
+
+    # Check input tensor shapes.
+    logits_shape = logits.shape
+    bitmask_shape = bitmask.shape
+    if logits.ndim == 1:
+        logits_shape = (1, logits_shape[0])
+    if bitmask.ndim == 1:
+        bitmask_shape = (1, bitmask_shape[0])
+
+    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
+    assert required_bitmask_width >= bitmask_shape[1], (
+        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
+        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
+    )
+
+    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)
+
+    num_rows = None
+    if isinstance(indices, list) or isinstance(indices, torch.Tensor):
+        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
+        num_rows = indices.shape[0]
+    else:
+        assert (
+            logits_shape[0] == bitmask_shape[0]
+        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
+        num_rows = logits_shape[0]
+
+    if NUM_SMS > 0:
+        grid = (NUM_SMS,)
+    else:
+        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+        grid = (num_rows * num_blocks,)
+        NUM_SMS = triton.next_power_of_2(grid[0])
+
+    apply_token_bitmask_inplace_kernel[grid](
+        logits,
+        bitmask,
+        indices,
+        num_rows,
+        vocab_size,
+        logits_shape[1],
+        bitmask_shape[1],
+        NUM_SMS,
+        BLOCK_SIZE,
+        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
+        num_stages=3,
+    )
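The kernel consumes a "keep" mask packed 32 tokens per int32 word: bit j of word w covers token id 32*w + j, with bit 1 keeping the logit and bit 0 forcing it to -inf. A minimal usage sketch (assumes a CUDA device; the shapes and allow-lists are illustrative, not taken from the diff):

```python
import numpy as np
import torch

from sglang.srt.constrained.triton_ops.bitmask_ops import (
    apply_token_bitmask_inplace_triton,
)

vocab_size = 100
logits = torch.randn(2, vocab_size, device="cuda")

# Per-row allow-lists; every token not listed gets its logit set to -inf.
allowed = [[1, 5, 42], [0, 99]]
num_words = (vocab_size + 31) // 32
packed = np.zeros((2, num_words), dtype=np.uint32)
for row, toks in enumerate(allowed):
    for t in toks:
        packed[row, t // 32] |= np.uint32(1) << np.uint32(t % 32)

# The kernel asserts int32, so reinterpret the uint32 bit pattern.
bitmask = torch.from_numpy(packed.view(np.int32)).to("cuda")

apply_token_bitmask_inplace_triton(logits, bitmask)
assert torch.isinf(logits[0, 2]).item() and not torch.isinf(logits[0, 5]).item()
```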
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -25,13 +25,16 @@ from xgrammar import (
     StructuralTagItem,
     TokenizerInfo,
     allocate_token_bitmask,
-    apply_token_bitmask_inplace,
 )
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
+from sglang.srt.constrained.triton_ops.bitmask_ops import (
+    apply_token_bitmask_inplace_triton,
+)
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -48,12 +51,25 @@ class XGrammarGrammar(BaseGrammarObject):
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
     ) -> None:
+        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
 
+        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
+        # class init site to avoid re-initializing CUDA in forked subprocess.
+        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+
+        self.use_token_bitmask_triton = get_bool_env_var(
+            "SGLANG_TOKEN_BITMASK_TRITON", "false"
+        )
+        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
+            "cuda", None
+        )
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
 
@@ -96,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
     def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
         return vocab_mask.to(device, non_blocking=True)
 
-
-
-
+    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if (
+            not self.use_token_bitmask_triton
+            and logits.device.type == "cuda"
+            and self.apply_vocab_mask_cuda
+        ):
+            return self.apply_vocab_mask_cuda(logits, vocab_mask)
+        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            return self.apply_vocab_mask_cpu(logits, vocab_mask)
+        apply_token_bitmask_inplace_triton(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
sglang/srt/custom_op.py
CHANGED
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
             return self.forward_hip
         else:
             return self.forward_native
-
-
-if _is_cuda:
-    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
-    def scaled_fp8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None,
-        num_token_padding: Optional[int] = None,
-        use_per_token_if_dynamic: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Quantize input tensor to FP8 (8-bit floating point) format.
-
-        Args:
-            input (torch.Tensor): Input tensor to be quantized
-            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-                If None, scales will be computed dynamically.
-            num_token_padding (Optional[int]): If specified, pad the first dimension
-                of the output to at least this value.
-            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-                determines the quantization granularity:
-                - True: compute scale per token
-                - False: compute single scale per tensor
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-                - quantized_tensor: The FP8 quantized version of input
-                - scale_tensor: The scaling factors used for quantization
-
-        Raises:
-            AssertionError: If input is not 2D or if static scale's numel != 1
-        """
-        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-        shape = input.shape
-        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-        if num_token_padding:
-            shape = (max(num_token_padding, input.shape[0]), shape[1])
-        output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-        if scale is None:
-            # Dynamic scaling
-            if use_per_token_if_dynamic:
-                scale = torch.empty(
-                    (shape[0], 1), device=input.device, dtype=torch.float32
-                )
-                sgl_per_token_quant_fp8(input, output, scale)
-            else:
-                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-                sgl_per_tensor_quant_fp8(
-                    input, output, scale, is_static=False
-                )  # False for dynamic
-        else:
-            # Static scaling
-            assert (
-                scale.numel() == 1
-            ), f"Expected scalar scale, got numel={scale.numel()}"
-            sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=True
-            )  # True for static
-
-        return output, scale
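The removed helper wrapped the sgl_kernel CUDA ops. For readers tracking the semantics rather than the kernels, a rough pure-PyTorch equivalent of the dynamic per-token mode (use_per_token_if_dynamic=True) is sketched below; it is a reference illustration only, not the optimized path, and assumes the CUDA-style float8_e4m3fn format:

```python
import torch


def per_token_fp8_quant_reference(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference sketch: one float32 scale per row, values clamped into fp8 range."""
    assert x.ndim == 2, f"Expected 2D input tensor, got {x.ndim}D"
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Scale so the largest |value| in each row maps onto the fp8 maximum.
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / finfo.max
    quantized = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return quantized, scale.to(torch.float32)
```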
sglang/srt/disaggregation/base/conn.py
ADDED
@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.server_args import ServerArgs
+
+
+class KVArgs:
+    engine_rank: int
+    kv_data_ptrs: list[int]
+    kv_data_lens: list[int]
+    kv_item_lens: list[int]
+    aux_data_ptrs: list[int]
+    aux_data_lens: list[int]
+    aux_item_lens: list[int]
+    ib_device: str
+    gpu_id: int
+
+
+class KVPoll:
+    Failed = 0
+    Bootstrapping = 1
+    WaitingForInput = 2
+    Transferring = 3
+    Success = 4
+
+
+class BaseKVManager(ABC):
+    """Base class for managing transfers states"""
+
+    @abstractmethod
+    def __init__(
+        self,
+        args: KVArgs,
+        disaggregation_mode: DisaggregationMode,
+        server_args: ServerArgs,
+    ): ...
+
+
+class BaseKVSender(ABC):
+
+    @abstractmethod
+    def __init__(
+        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
+    ): ...
+
+    @abstractmethod
+    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
+        """
+        Notify the decoder server about the kv indices length and aux index
+        """
+        ...
+
+    @abstractmethod
+    def send(self, kv_indices: npt.NDArray[np.int64]):
+        """
+        Send the kv cache at the given kv indices to the decoder server
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVReceiver(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ): ...
+
+    @abstractmethod
+    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
+        """
+        Notify the prefill server about the kv indices and aux index
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVBootstrapServer(ABC):
+    @abstractmethod
+    def __init__(self, port: int): ...
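The interface is intentionally small: a manager holds shared transfer state, senders and receivers are created per request, and poll() drives a state machine over the KVPoll constants. A minimal, hypothetical sender sketch follows; the _ROOMS dict stands in for a real transport (such as the mooncake engine added in this release), and InProcessKVSender and _ROOMS are illustrative names, not part of sglang:

```python
from typing import Optional

import numpy as np
import numpy.typing as npt

from sglang.srt.disaggregation.base import BaseKVSender, KVPoll

_ROOMS: dict = {}  # stand-in transport, keyed by bootstrap_room


class InProcessKVSender(BaseKVSender):
    def __init__(self, mgr, bootstrap_addr: str, bootstrap_room: int):
        self.room = bootstrap_room

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        _ROOMS[self.room] = {"expected": num_kv_indices, "aux": aux_index}

    def send(self, kv_indices: npt.NDArray[np.int64]):
        _ROOMS[self.room]["kv_indices"] = kv_indices.copy()

    def poll(self) -> KVPoll:
        state = _ROOMS.get(self.room)
        if state is None:
            return KVPoll.Bootstrapping
        return KVPoll.Success if "kv_indices" in state else KVPoll.WaitingForInput

    def failure_exception(self):
        raise RuntimeError(f"kv transfer failed for room {self.room}")
```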
sglang/srt/disaggregation/decode.py
CHANGED
@@ -24,12 +24,18 @@ import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
+import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.disaggregation.
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
     ReqToMetadataIdxAllocator,
+    TransferBackend,
+    get_kv_class,
+    kv_to_page_indices,
     poll_and_all_reduce,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -49,7 +55,7 @@ if TYPE_CHECKING:
 @dataclass
 class DecodeRequest:
     req: Req
-    kv_receiver:
+    kv_receiver: BaseKVReceiver
     waiting_for_input: bool = False
     metadata_buffer_index: int = -1
 
@@ -73,6 +79,7 @@ class DecodePreallocQueue:
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
+        transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -92,9 +99,10 @@ class DecodePreallocQueue:
 
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
+        self.transfer_backend = transfer_backend
         self.kv_manager = self._init_kv_manager()
 
-    def _init_kv_manager(self) ->
+    def _init_kv_manager(self) -> BaseKVManager:
         kv_args = KVArgs()
         kv_args.engine_rank = self.tp_rank
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
@@ -114,14 +122,19 @@ class DecodePreallocQueue:
         kv_args.aux_item_lens = [
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
-        kv_args.ib_device =
-
+        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
+        kv_args.gpu_id = self.scheduler.gpu_id
+        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
+        kv_manager = kv_manager_class(
+            kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
+        )
         return kv_manager
 
     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""
 
-
+        kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
+        kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
@@ -186,13 +199,17 @@ class DecodePreallocQueue:
             ]
             .cpu()
             .numpy()
+            .astype(np.int64)
         )
 
         decode_req.metadata_buffer_index = (
             self.req_to_metadata_buffer_idx_allocator.alloc()
         )
         assert decode_req.metadata_buffer_index is not None
-
+        page_indices = kv_to_page_indices(
+            kv_indices, self.token_to_kv_pool_allocator.page_size
+        )
+        decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
         preallocated_reqs.append(decode_req)
         indices_to_remove.add(i)
 
@@ -232,10 +249,30 @@ class DecodePreallocQueue:
         assert req_pool_indices is not None
 
         req.req_pool_idx = req_pool_indices[0]
-
-
-
-
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            kv_loc = self.token_to_kv_pool_allocator.alloc(
+                len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+            )
+        else:
+            num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+            kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
+                prefix_lens=torch.tensor(
+                    [0],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                seq_lens=torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                last_loc=torch.tensor(
+                    [-1],
+                    dtype=torch.int64,
+                    device=self.token_to_kv_pool_allocator.device,
+                ),
+                extend_num_tokens=num_tokens,
+            )
         assert kv_loc is not None
 
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
@@ -406,6 +443,38 @@ class ScheduleBatchDisaggregationDecodeMixin:
 
 class SchedulerDisaggregationDecodeMixin:
 
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]: