sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/managers/cache_controller.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +52 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +9 -1
- sglang/srt/mem_cache/memory_pool.py +4 -1
- sglang/srt/model_executor/cuda_graph_runner.py +59 -16
- sglang/srt/model_executor/forward_batch_info.py +13 -4
- sglang/srt/models/deepseek_v2.py +180 -177
- sglang/srt/models/grok.py +374 -119
- sglang/srt/openai_api/adapter.py +22 -20
- sglang/srt/server_args.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +24 -22
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py
ADDED
@@ -0,0 +1,342 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.moe.topk import fused_topk
+
+
+@triton.jit
+def fused_moe_router_kernel(
+    input_ptr,  # input (bs, hidden_dim)
+    moe_router_weight_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    # moe_router_weight is k major
+    expert_offsets = tl.arange(0, num_experts)[:, None]
+    router_mask = mask[None, :]
+    w_router = tl.load(
+        moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
+        mask=router_mask,
+        other=0.0,
+    )
+
+    x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
+
+    # todo: tl.dot?
+    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
+
+    # logit softcap
+    logits_scaled = logits / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    top = exped - 1
+    bottom = exped + 1
+    logits_softcapped = top / bottom * moe_softcapping
+
+    # topk
+    # assert 1 <= topk <= num_experts
+
+    # 5.38 us
+
+    top1 = tl.argmax(logits_softcapped, axis=0)
+    tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us
+
+    top1_v = tl.max(logits_softcapped, axis=0)
+    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
+
+    tl.store(
+        topk_weights_ptr + pid * topk + 0,
+        invsumexp,
+    )  # 5.73 us
+
+    if topk >= 2:
+        top2 = tl.argmax(
+            tl.where(
+                tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
+            ),
+            axis=0,
+        )
+        tl.store(topk_ids_ptr + pid * topk + 1, top2)
+        top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
+        tl.store(
+            topk_weights_ptr + pid * topk + 1,
+            tl.exp(top2_v - top1_v) * invsumexp,
+        )  # 5.95us
+
+    # probably slow
+    if topk > 2:
+        topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
+        )
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
+        )
+        for i in range(2, topk):
+            topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
+            topk_mask = tl.where(
+                tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
+            )
+            tl.store(topk_ids_ptr + pid * topk + i, topi)
+            topi_v = tl.sum(
+                logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
+            )
+            tl.store(
+                topk_weights_ptr + pid * topk + i,
+                tl.exp(topi_v - top1_v) * invsumexp,
+            )
+    # assert not moe_renormalize, "moe weight renormalization not implemented"
+
+
+def fused_moe_router_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = lambda meta: (bs,)
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_moe_router_kernel[grid](
+        x,
+        router_weight,
+        topk_weights,
+        topk_ids,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        hidden_dim=hidden_dim,
+        **config,
+    )
+
+    return topk_weights, topk_ids
+
+
+@triton.jit
+def fused_moe_router_large_bs_kernel(
+    a_ptr,  # input (bs, hidden_dim)
+    b_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    bs,
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,  # only support topk == 1
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    stride_am: tl.constexpr,
+    stride_bn: tl.constexpr,
+):
+
+    # 1. get block id
+    pid = tl.program_id(axis=0)
+
+    # 2. create pointers for the first block of A and B
+    # 2.1. setup a_ptrs with offsets in m and k
+    offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
+    bs_mask = offs_m < bs
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
+
+    # 2.2. setup b_ptrs with offsets in k and n.
+    # Note: b matrix is k-major.
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
+    expert_mask = offs_n < num_experts
+    b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
+
+    # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
+    # 3.1. iterate in K dimension
+    # 3.2. transpose tile B
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
+        a = tl.load(
+            a_ptrs,
+            mask=bs_mask,
+            other=0.0,
+        ).to(tl.float32)
+        b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
+        acc += tl.dot(a, b)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+
+    # 4. logit softcap
+    logits_scaled = acc / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+
+    # 5. top1
+    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
+    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    top1_v = tl.max(
+        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+    )
+    invsumexp = 1.0 / tl.sum(
+        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    )
+
+    # 6. store to output
+    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    topk_mask = offs_topk < bs
+    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    tl.store(
+        topk_weights_ptr + offs_topk,
+        invsumexp,
+        mask=topk_mask,
+    )
+
+
+def fused_moe_router_large_bs_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+    BLOCK_SIZE_M: int,
+    BLOCK_SIZE_N: int,
+    BLOCK_SIZE_K: int,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    assert num_experts <= BLOCK_SIZE_N
+    assert hidden_dim % BLOCK_SIZE_K == 0
+    assert topk == 1
+
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
+
+    fused_moe_router_large_bs_kernel[grid](
+        a_ptr=x,
+        b_ptr=router_weight,
+        topk_weights_ptr=topk_weights,
+        topk_ids_ptr=topk_ids,
+        bs=bs,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        K=hidden_dim,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        stride_am=hidden_dim,
+        stride_bn=hidden_dim,
+    )
+
+    return topk_weights, topk_ids
+
+
+def fused_moe_router_shim(
+    moe_softcapping,
+    hidden_states,
+    gating_output,
+    topk,
+    renormalize,
+):
+    assert not renormalize
+    assert (
+        len(hidden_states.shape) == 2
+        and hidden_states.shape[1] == gating_output.shape[1]
+    )
+    bs, hidden_dim = hidden_states.shape
+    num_experts = gating_output.shape[0]
+    BLOCK_SIZE_M = 32
+    BLOCK_SIZE_N = 16
+    BLOCK_SIZE_K = 256
+    if (
+        bs >= 512
+        and topk == 1
+        and num_experts <= BLOCK_SIZE_N
+        and hidden_dim % BLOCK_SIZE_K == 0
+    ):
+        return fused_moe_router_large_bs_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            BLOCK_SIZE_K=BLOCK_SIZE_K,
+        )
+    else:
+        return fused_moe_router_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+        )
+
+
+class FusedMoeRouter:
+    def __init__(self, router_linear, topk, moe_softcapping) -> None:
+        self.router_linear = router_linear
+        self.topk = topk
+        self.moe_softcapping = moe_softcapping
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_vllm(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_moe_router_shim(
+            moe_softcapping=self.moe_softcapping,
+            hidden_states=x,
+            gating_output=self.router_linear.weight,
+            topk=self.topk,
+            renormalize=False,
+        )
+
+    def forward_vllm(
+        self,
+        x: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # g, _ = self.router_linear.forward(x)
+        g = x.float() @ self.router_linear.weight.T.float()
+
+        g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
+
+        return fused_topk(x, g, self.topk, False)
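The new `moe/router.py` above fuses the Grok-style MoE gating into Triton kernels: it multiplies the hidden states by the router weight, applies a logit softcap (the `(exp(2x) - 1) / (exp(2x) + 1)` expression is just `tanh(x)`, so the capped logits equal `moe_softcapping * tanh(logits / moe_softcapping)`), then writes the top-k expert ids and their softmax probabilities taken over all experts. For orientation, here is a rough pure-PyTorch sketch of the same math; the function name and shapes are illustrative and not part of the package:

```python
import torch


def reference_moe_router(x, router_weight, topk, moe_softcapping):
    """Plain-PyTorch sketch of the fused router: softcapped logits -> top-k ids
    plus softmax weights over *all* experts (no renormalization)."""
    # x: (bs, hidden_dim), router_weight: (num_experts, hidden_dim)
    logits = x.float() @ router_weight.float().T                     # (bs, num_experts)
    logits = torch.tanh(logits / moe_softcapping) * moe_softcapping  # logit softcap
    _, topk_ids = logits.topk(topk, dim=-1)                          # (bs, topk)
    probs = torch.softmax(logits, dim=-1)
    topk_weights = probs.gather(-1, topk_ids)
    return topk_weights, topk_ids.to(torch.int32)


# Example: 8 tokens routed over 8 experts, picking the top-2.
x = torch.randn(8, 1024)
w = torch.randn(8, 1024)
weights, ids = reference_moe_router(x, w, topk=2, moe_softcapping=30.0)
```

In the diff itself, `fused_moe_router_shim` then dispatches between the per-token kernel and the tiled large-batch kernel based on batch size, top-k, expert count, and hidden-dim divisibility.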
sglang/srt/managers/cache_controller.py
CHANGED
@@ -248,6 +248,8 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
+        # to ensure the device indices are ready before accessed by another CUDA stream
+        torch.cuda.current_stream().synchronize()
         self.load_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
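The two added lines guard a cross-stream hazard: `device_indices` is produced on the current CUDA stream, but the worker that drains `load_queue` reads it from a different stream, so the producing stream is drained before the hand-off. A minimal sketch of the same pattern (plus the lighter event-based alternative), assuming a CUDA device and placeholder tensors standing in for the real cache indices:

```python
import torch

side_stream = torch.cuda.Stream()

# Work queued on the default stream produces the indices.
indices = torch.arange(1024, device="cuda")

# Option used in the diff: block the host until the producing stream is done,
# so the tensor is fully materialized before another stream reads it.
torch.cuda.current_stream().synchronize()

# Lighter alternative: record an event and have the side stream wait on it;
# this orders the two streams without blocking the host.
ready = torch.cuda.Event()
ready.record(torch.cuda.current_stream())
side_stream.wait_event(ready)

with torch.cuda.stream(side_stream):
    doubled = indices * 2  # consumer work now runs after the producer finished
```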
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

-    def __init__(self, server_args, port_args) -> None:
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -361,7 +361,7 @@ class Req:
         ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
             self.output_token_ids_logprobs_idx
         ) = None
-        self.hidden_states = []
+        self.hidden_states: List[List[float]] = []

         # Embedding (return values)
         self.embedding = None
sglang/srt/managers/scheduler.py
CHANGED
@@ -434,6 +434,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                     tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                    page_size=self.page_size,
                 )
             else:
                 self.tree_cache = RadixCache(
@@ -997,7 +998,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         # Handle DP attention
         if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)

         return ret

@@ -1269,39 +1270,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
+            global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )
+
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0
+
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when idle batch is there.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0

-
-
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
         torch.distributed.all_gather_into_tensor(
-
-
+            global_info.flatten(),
+            local_info,
             group=self.tp_cpu_group,
         )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()

-        if local_batch is None and
+        if local_batch is None and max(global_num_tokens) > 0:
             local_batch = self.get_idle_batch()

         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

             # Check forward mode for cuda graph
             if not self.server_args.disable_cuda_graph:
-
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        return local_batch
+        return local_batch, any(is_extend_in_batch)

     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
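In `prepare_dp_attn_batch`, the post1 release replaces the separate `all_reduce` on `forward_mode_state` with a single `all_gather_into_tensor` of a packed int64 vector: per-rank token count, CUDA-graph eligibility, logprob token count, and an is-extend flag travel together, and graph eligibility then falls out as the minimum across ranks. A stripped-down sketch of that packing pattern, assuming an already-initialized process group (the helper name is illustrative, not sglang API):

```python
import torch
import torch.distributed as dist


def gather_batch_info(num_tokens, can_cuda_graph, num_logprob_tokens, is_extend, group=None):
    world_size = dist.get_world_size(group)
    # Pack all per-rank scalars into one tensor so a single collective suffices.
    local_info = torch.tensor(
        [num_tokens, int(can_cuda_graph), num_logprob_tokens, int(is_extend)],
        dtype=torch.int64,
    )
    global_info = torch.empty((world_size, 4), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)

    global_num_tokens = global_info[:, 0].tolist()
    # CUDA graphs are only usable if every rank is running a decode/idle batch.
    can_cuda_graph = bool(global_info[:, 1].min().item())
    any_extend = bool(global_info[:, 3].max().item())
    return global_num_tokens, can_cuda_graph, global_info[:, 2].tolist(), any_extend
```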
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin:
                         ]
                         .cpu()
                         .clone()
+                        .tolist()
                     )

                 if req.grammar is not None:
@@ -245,7 +246,9 @@ class SchedulerOutputProcessorMixin:
             )

             if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(
+                req.hidden_states.append(
+                    logits_output.hidden_states[i].cpu().clone().tolist()
+                )

             if req.grammar is not None and batch.spec_algorithm.is_none():
                 req.grammar.accept_token(next_token_id)
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -25,11 +25,17 @@ class HiRadixCache(RadixCache):
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
+        page_size: int,
     ):
+        if page_size != 1:
+            raise ValueError(
+                "Page size larger than 1 is not yet supported in HiRadixCache."
+            )
         self.token_to_kv_pool_host = MHATokenToKVPoolHost(
             token_to_kv_pool_allocator.get_kvcache()
         )
         self.tp_group = tp_cache_group
+        self.page_size = page_size

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -45,7 +51,9 @@ class HiRadixCache(RadixCache):
         # todo: dynamically adjust the threshold
         self.write_through_threshold = 1
         self.load_back_threshold = 10
-        super().__init__(
+        super().__init__(
+            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+        )

     def reset(self):
         TreeNode.counter = 0
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -326,7 +326,7 @@ class MHATokenToKVPool(KVCache):
         cache_k = cache_k.view(self.store_dtype)
         cache_v = cache_v.view(self.store_dtype)

-        if self.capture_mode:
+        if self.capture_mode and cache_k.shape[0] < 4:
             self.alt_stream.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self.alt_stream):
                 self.k_buffer[layer_id][loc] = cache_k
@@ -591,6 +591,9 @@ class MHATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
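The new `get_flat_data_by_layer` slices the host KV buffer for a single layer, complementing `get_flat_data`, which gathers every layer at once; with the indexing used here, the host buffer is laid out as `(2, layer_num, tokens, ...)` with K and V on the first axis, so per-layer slices let hierarchical-cache transfers be issued layer by layer. A tiny illustration of the two access patterns (buffer sizes and dtype are made up for the example):

```python
import torch

# Illustrative host-side layout: (2, num_layers, num_tokens, num_heads, head_dim);
# index 0 on the first axis holds K, index 1 holds V.
kv_buffer = torch.zeros(2, 4, 1024, 8, 128, dtype=torch.float16)
indices = torch.tensor([3, 7, 42])

# Existing get_flat_data: all layers for the selected tokens.
all_layers = kv_buffer[:, :, indices]      # shape (2, 4, 3, 8, 128)

# New get_flat_data_by_layer: one layer at a time, e.g. so the copy of layer i
# can overlap with work on earlier layers.
per_layer = [kv_buffer[:, layer_id, indices] for layer_id in range(kv_buffer.shape[1])]
assert per_layer[0].shape == (2, 3, 8, 128)
```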