sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/gate_up_lora_b.py

@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-
-
-
-
-
+    """
+    This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, 2 * output_dim, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, 2 * output_dim).
+    """
     # output_dim >> K

     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # gate_up_id decides which of gate or up (0: gate, 1: up)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     gate_up_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)

     # Adjust K (rank) according to the specific LoRA adapter

@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
             mask=(k_offset[:, None] < K - k * BLOCK_K)
-
+            & (n_offset[None, :] < output_dim),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)

@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len)
-
-    partial_sum += tl.load(output_ptr, mask=output_mask)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)



@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
     )

     if base_output is None:
-        output = torch.
-        fuse_scaling_add = False
+        output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True

     _gate_up_lora_b_kernel[grid_b](
         x,

@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )

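Note on the change above: the kernel now resolves the adapter (`w_index`, `rank`) before anything else and returns early when the rank is 0, and `gate_up_lora_b_fwd` zero-initializes its output when no `base_output` is supplied, which is what makes the removed `fuse_scaling_add` flag unnecessary. Below is a minimal PyTorch sketch of the rank-0 convention the new docstring cites; the variable names are illustrative and not taken from sglang.

```python
import torch

m, n = 4, 8
lora_a_out = torch.randn(m, 0)     # activations of a rank-0 adapter: zero-width
lora_b_weight = torch.randn(0, n)  # rank-0 LoRA B weight: zero-height
product = lora_a_out @ lora_b_weight
# PyTorch defines (m, 0) @ (0, n) as an all-zero (m, n) matrix.
assert product.shape == (m, n) and torch.count_nonzero(product) == 0

# gate_up_lora_b_fwd now starts from a zero buffer when base_output is None,
# so skipping the kernel for rank-0 sequences leaves exactly this all-zero
# contribution behind.
s, output_dim = 4, 8
output = torch.zeros((s, 2 * output_dim))
```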
sglang/srt/lora/triton_ops/qkv_lora_b.py

@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-
-
-
-
-
-
+    """
+    This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank. The second dimension is partitioned
+            for Q, K, and V.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, N_Q + 2 * N_KV, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, N_Q + 2 * N_KV).
+    """

     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     qkv_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)

@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)

@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
-
-    partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)



@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
     )

     if base_output is None:
-        output = torch.
-        fuse_scaling_add = False
+        output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True

     _qkv_lora_b_kernel[grid_b](
         x,

@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )

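The same hunks also rewrite the tile masks: Python's `and` is replaced with the bitwise `&`, and the masks gain an explicit output-dimension bound (`output_dim` above, `n_size` here, `N` in the sgemm kernels). The snippet below is a plain-PyTorch analogue, not Triton, showing why elementwise masks need `&`; the shapes are illustrative only.

```python
import torch

s_offset = torch.arange(4)[:, None]   # (4, 1) row offsets within a tile
k_offset = torch.arange(8)[None, :]   # (1, 8) column offsets within a tile
seg_len, K = 3, 5

# Bitwise & broadcasts into a (4, 8) elementwise mask.
mask = (s_offset < seg_len) & (k_offset < K)
print(mask.sum().item())  # 15 valid positions

# Python's `and` is not an elementwise combine; on multi-element tensors it
# needs a single truth value and fails.
try:
    (s_offset < seg_len) and (k_offset < K)
except RuntimeError as err:
    print("`and` cannot combine tensor masks:", err)
```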
sglang/srt/lora/triton_ops/sgemm_lora_a.py

@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
-
-
-
-
+    """
+    Computes a segmented batched matrix multiplication for the LoRA A matrix.
+
+    The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
+    stores the product of the input `x` and the LoRA weights for the corresponding
+    sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+    as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+    Args:
+        x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+            is the sum of all sequence lengths in the batch.
+        weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`.
+    """

     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
-    pid = tl.program_id(axis=0)
-    seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
-    seg_start = tl.load(seg_indptr + batch_id)
     rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+    if rank == 0:
+        return
+
+    pid = tl.program_id(axis=0)
+    seg_start = tl.load(seg_indptr + batch_id)
+    seg_len = tl.load(seg_lens + batch_id)
+
     # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
     N = tl.minimum(N, rank * stack_num)


@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)

@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
     output_ptr = (output + seg_start * output_stride_0) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
     tl.store(output_ptr, partial_sum, mask=output_mask)


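For reference, here is a hedged pure-PyTorch sketch of the segmented matmul that `_sgemm_lora_a_kernel`'s new docstring describes. The helper name `sgemm_lora_a_reference` and the explicit loop are illustrative only; the real kernel writes just the first `rank * stack_num` columns of each segment rather than zero-filling the whole buffer.

```python
import torch

def sgemm_lora_a_reference(x, weights, seg_lens, weight_indices, lora_ranks, stack_num=1):
    # x: (s, K), weights: (num_lora, N, K), returns output: (s, N)
    s, K = x.shape
    num_lora, N, _ = weights.shape
    output = torch.zeros(s, N, dtype=x.dtype, device=x.device)
    seg_start = 0
    for seg_id, seg_len in enumerate(seg_lens.tolist()):
        w_index = int(weight_indices[seg_id])
        rank = int(lora_ranks[w_index])
        n = min(N, rank * stack_num)
        if n > 0:  # rank 0: nothing to write, matching the kernel's early return
            output[seg_start:seg_start + seg_len, :n] = (
                x[seg_start:seg_start + seg_len] @ weights[w_index, :n, :].T
            )
        seg_start += seg_len
    return output

# Example: two sequences of 3 tokens, K=16, two adapters with N=8;
# the second adapter has rank 0, so its rows stay zero.
x = torch.randn(6, 16)
weights = torch.randn(2, 8, 16)
out = sgemm_lora_a_reference(
    x,
    weights,
    seg_lens=torch.tensor([3, 3]),
    weight_indices=torch.tensor([0, 1]),
    lora_ranks=torch.tensor([8, 0]),
)
assert torch.count_nonzero(out[3:]) == 0
```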
sglang/srt/lora/triton_ops/sgemm_lora_b.py

@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-
-
-
+    """
+    Computes a segmented batched matrix multiplication for the LoRA B matrix
+    and adds the result to the output in-place.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
+            of shape `(s, K)`, where `s` is the total number of tokens.
+        weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
+            the base model's output for a fused add operation.
+    """

     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)

@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(

@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = s_offset[:, None] < seg_len
-
-    partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)



@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
     )

     if base_output is None:
-        output = torch.
-        fuse_scaling_add = False
+        output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True

     _sgemm_lora_b_kernel[grid](
         x,

@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
     return output
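And the matching B-side sketch: with `fuse_scaling_add` removed, the kernel always accumulates `scaling * (x @ B^T)` into `output`, which is why `sgemm_lora_b_fwd` now zero-fills the buffer when `base_output` is None. The loop below is an assumed, simplified pure-PyTorch reference, not the launched Triton kernel.

```python
import torch

def sgemm_lora_b_reference(x, weights, seg_lens, weight_indices, lora_ranks,
                           scalings, base_output=None):
    # x: (s, max_rank), weights: (num_lora, N, K), returns output: (s, N)
    s, max_rank = x.shape
    num_lora, N, _ = weights.shape
    output = (
        torch.zeros(s, N, dtype=x.dtype, device=x.device)
        if base_output is None
        else base_output
    )
    seg_start = 0
    for seg_id, seg_len in enumerate(seg_lens.tolist()):
        w_index = int(weight_indices[seg_id])
        rank = int(lora_ranks[w_index])
        if rank > 0:  # rank 0 contributes nothing, as in the kernel's early return
            seg = slice(seg_start, seg_start + seg_len)
            output[seg] += scalings[w_index] * (
                x[seg, :rank] @ weights[w_index, :, :rank].T
            )
        seg_start += seg_len
    return output
```

Accumulating into `output` (rather than returning a separate delta) is what lets the LoRA contribution be fused directly onto the base model's projection output when `base_output` is passed in.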
sglang/srt/managers/cache_controller.py

@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-import concurrent.futures
 import logging
 import math
 import threading

@@ -169,12 +168,23 @@ class HiCacheController:
         page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
+        io_backend: str = "",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
+        # using kernel for small page KV cache transfer and DMA for large pages
+        if not io_backend:
+            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+            self.io_backend = (
+                "direct"
+                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+                else "kernel"
+            )
+        else:
+            self.io_backend = io_backend

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)

@@ -203,12 +213,7 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()

         self.write_thread = threading.Thread(
-            target=
-            self.write_thread_func_buffer
-            if self.page_size == 1
-            else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True

@@ -229,12 +234,7 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()

         self.write_thread = threading.Thread(
-            target=
-            self.write_thread_func_buffer
-            if self.page_size == 1
-            else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True

@@ -281,6 +281,15 @@ class HiCacheController:
         )
         return device_indices

+    def move_indices(self, host_indices, device_indices):
+        # move indices to GPU if using kernels, to host if using direct indexing
+        if self.io_backend == "kernel":
+            return host_indices.to(self.mem_pool_device.device), device_indices
+        elif self.io_backend == "direct":
+            return host_indices, device_indices.cpu()
+        else:
+            raise ValueError(f"Unsupported io backend")
+
     def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.

@@ -289,10 +298,14 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                self.
-                operation.host_indices,
-
-
+                host_indices, device_indices = self.move_indices(
+                    operation.host_indices, operation.device_indices
+                )
+                self.mem_pool_device.backup_to_host_all_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    self.io_backend,
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)

@@ -304,27 +317,6 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

-    def load_thread_func_direct(self):
-        """
-        Directly load KV caches from host memory to device memory without buffering.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_host.get_flat_data(
-                    operation.host_indices
-                )
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.

@@ -349,22 +341,18 @@ class HiCacheController:

        # start layer-wise KV cache transfer from CPU to GPU
        self.layer_done_counter.reset()
+       host_indices, device_indices = self.move_indices(
+           batch_operation.host_indices, batch_operation.device_indices
+       )
        for i in range(self.mem_pool_host.layer_num):
-
-
-
-
-
-
-
-
-           self.mem_pool_host.load_page_per_layer(
-               batch_operation.host_indices,
-               batch_operation.device_indices,
-               self.mem_pool_device,
-               i,
-           )
-           self.load_stream.synchronize()
+           self.mem_pool_device.load_from_host_per_layer(
+               self.mem_pool_host,
+               host_indices,
+               device_indices,
+               i,
+               self.io_backend,
+           )
+           self.load_stream.synchronize()
            self.layer_done_counter.increment()

        self.mem_pool_host.complete_io(batch_operation.host_indices)

@@ -372,148 +360,6 @@ class HiCacheController:
            if node_id != 0:
                self.ack_load_queue.put(node_id)

-    def write_aux_func(self, no_wait=False):
-        """
-        Auxiliary function to prepare the buffer for write operations.
-        """
-        torch.cuda.set_stream(self.write_stream)
-
-        def _to_op(op_):
-            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
-            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
-                self.mem_pool_host.device
-            )
-            self.write_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices) // self.write_buffer.max_buffer_size
-                )
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _to_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    _to_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_aux_func(self):
-        """
-        Auxiliary function to prepare the buffer for load operations.
-        """
-
-        def _pin_op(op_, put=True):
-            op_.data = (
-                self.mem_pool_host.get_flat_data(op_.host_indices)
-                .contiguous()
-                .pin_memory()
-            )
-            if put:
-                self.load_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _pin_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _pin_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        split_args = [(op_, True) for op_ in split_ops[:-1]]
-                        split_args.append((split_ops[-1], False))
-                        # Spawn threads to pin each op concurrently
-                        with concurrent.futures.ThreadPoolExecutor() as executor:
-                            pinned_ops = list(
-                                executor.map(
-                                    lambda x: _pin_op(x[0], put=x[1]), split_args
-                                )
-                            )
-                        # preserve the order of last op to ensure correct ack
-                        self.load_buffer.put(pinned_ops[-1])
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
-                    or self.load_queue.empty()
-                    or self.load_buffer.empty()
-                ):
-                    _pin_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    # todo (zhiqiang): double buffering to be deprecated
-    def write_thread_func_buffer(self):
-        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
-        aux_thread.start()
-
-        while not self.stop_event.is_set():
-            operation = self.write_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_write_queue.put(node_id)
-        aux_thread.join()
-
-    def load_thread_func_buffer(self):
-        torch.cuda.set_stream(self.load_stream)
-        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.load_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_device.transfer(operation.device_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
-        aux_thread.join()
-
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
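The new `io_backend` selection and `move_indices` placement logic can be exercised in isolation. The stand-in class below is a simplified sketch: the threshold constant, backend names, and index-placement rules come from the diff, while the class name and everything else is illustrative and not the real `HiCacheController`.

```python
import torch

# Constant as it appears in the diff: pages at or above this size use "direct"
# DMA-style copies; smaller pages use the "kernel" backend.
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64

class IoBackendPolicy:
    def __init__(self, page_size: int, device: str = "cuda", io_backend: str = ""):
        self.page_size = page_size
        self.device = device
        # kernel-based transfer for small pages, direct copies for large pages
        if not io_backend:
            self.io_backend = (
                "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"
            )
        else:
            self.io_backend = io_backend

    def move_indices(self, host_indices, device_indices):
        # kernel backend indexes from the GPU, so host indices move to the device;
        # direct backend indexes from the CPU, so device indices move to the host
        if self.io_backend == "kernel":
            return host_indices.to(self.device), device_indices
        elif self.io_backend == "direct":
            return host_indices, device_indices.cpu()
        raise ValueError(f"Unsupported io backend: {self.io_backend!r}")

# page_size=1 selects "kernel", page_size=128 selects "direct"
print(IoBackendPolicy(1, device="cpu").io_backend,
      IoBackendPolicy(128, device="cpu").io_backend)
```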
sglang/srt/managers/configure_logging.py

@@ -28,7 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     parser.add_argument("--log-requests", action="store_true")
-    parser.add_argument("--log-requests-level", type=int, default=
+    parser.add_argument("--log-requests-level", type=int, default=3)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )