sglang 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +26 -0
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +40 -18
- sglang/global_config.py +21 -16
- sglang/lang/chat_template.py +41 -6
- sglang/lang/interpreter.py +5 -1
- sglang/lang/ir.py +61 -25
- sglang/srt/constrained/__init__.py +3 -2
- sglang/srt/hf_transformers_utils.py +7 -3
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +181 -167
- sglang/srt/layers/logits_processor.py +55 -19
- sglang/srt/layers/radix_attention.py +33 -59
- sglang/srt/layers/token_attention.py +4 -8
- sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
- sglang/srt/managers/controller/infer_batch.py +244 -36
- sglang/srt/managers/controller/manager_single.py +1 -1
- sglang/srt/managers/controller/model_runner.py +69 -284
- sglang/srt/managers/controller/tp_worker.py +39 -20
- sglang/srt/managers/detokenizer_manager.py +4 -2
- sglang/srt/managers/io_struct.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/memory_pool.py +33 -6
- sglang/srt/model_config.py +6 -0
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/llama2.py +3 -3
- sglang/srt/models/llama_classification.py +10 -7
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/openai_api_adapter.py +2 -2
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +18 -8
- sglang/srt/server_args.py +24 -20
- sglang/srt/utils.py +68 -35
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/METADATA +19 -13
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/RECORD +40 -36
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
sglang/srt/layers/fused_moe.py
CHANGED
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 
@@ -108,12 +107,16 @@ def fused_moe_kernel(
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
 
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs =
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
 
     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
@@ -129,13 +132,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -146,9 +148,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
     if use_fp8:
@@ -158,15 +158,14 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
 def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -205,32 +204,38 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty(
-    num_tokens_post_pad = torch.empty((1),
-        expert_ids, num_tokens_post_pad)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
-def invoke_fused_moe_kernel(
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
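For orientation, here is a rough pure-PyTorch sketch of the contract described in the docstring above. It is illustrative only: the real work is done by the `ops.moe_align_block_size` CUDA kernel, and the helper name below is made up.

```python
import torch

def naive_moe_align_block_size(topk_ids, block_size, num_experts):
    # Illustrative re-statement of the kernel's contract, not the actual implementation.
    flat = topk_ids.flatten()
    pad_value = topk_ids.numel()  # padding slots point one past the last real entry
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        # indices into the flattened topk_ids that were routed to expert e
        idx = (flat == e).nonzero(as_tuple=True)[0].tolist()
        idx += [pad_value] * ((-len(idx)) % block_size)  # pad to a multiple of block_size
        sorted_ids += idx
        expert_ids += [e] * (len(idx) // block_size)     # one expert id per block
    return (
        torch.tensor(sorted_ids, dtype=torch.int32),
        torch.tensor(expert_ids, dtype=torch.int32),
        torch.tensor([len(sorted_ids)], dtype=torch.int32),
    )

# 4 tokens, each routed to 2 of 4 experts, block_size 4.
print(naive_moe_align_block_size(torch.tensor([[0, 1], [1, 2], [0, 3], [1, 3]]), 4, 4))
```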
@@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
 
-    grid = lambda META: (
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )
 
     fused_moe_kernel[grid](
         A,
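The launch grid above is one-dimensional: one program instance per (token-block, output-column-block) pair over the padded token list. A quick back-of-the-envelope with made-up sizes:

```python
import math

# Made-up sizes for illustration.
padded_tokens = 4096   # sorted_token_ids.shape[0] after moe_align_block_size
N = 14336              # B.shape[1], the per-expert output dimension
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 64

print(math.ceil(padded_tokens / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N))  # 64 * 224 = 14336
```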
@@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int,
     json_file_name = get_config_file_name(E, N, dtype)
 
     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
 
@@ -319,35 +325,35 @@ def get_default_config(
 ) -> Dict[str, int]:
     if dtype == "float8":
         config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 4
+            "num_stages": 4,
         }
         if M <= E:
             config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 4
+                "num_stages": 4,
             }
     else:
         config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
         }
         if M <= E:
            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
            }
     return config
 
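These defaults are only a fallback. When a tuned JSON config exists, `get_moe_configs` returns a dict keyed by the benchmarked batch size `M`, and `fused_experts` (below) picks the entry whose key is closest to the actual token count. A toy illustration of that lookup, with made-up values:

```python
# Keys are benchmarked token counts M; values are kernel configs (truncated here).
configs = {1: {"BLOCK_SIZE_M": 16}, 16: {"BLOCK_SIZE_M": 16}, 256: {"BLOCK_SIZE_M": 64}}
M = 180  # actual number of tokens in this batch
print(configs[min(configs.keys(), key=lambda x: abs(x - M))])  # nearest key is 256
```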
@@ -358,23 +364,17 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0],
-        "Number of tokens mismatch")
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
 
-    topk_weights = torch.empty(
-                               device=hidden_states.device)
-    token_expert_indicies = torch.empty(M,
-                                        topk,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
     ops.topk_softmax(
         topk_weights,
         topk_ids,
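Functionally this is standard softmax gating. A plain-PyTorch equivalent (ignoring the fused `topk_softmax` op and its `token_expert_indicies` scratch buffer) looks roughly like:

```python
import torch

def reference_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over the expert dimension, then keep the top-k experts per token.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

w, ids = reference_topk(torch.randn(4, 8), topk=2, renormalize=True)  # 4 tokens, 8 experts
print(w.sum(dim=-1))  # ~1.0 per token after renormalization
```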
@@ -388,27 +388,27 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
-def fused_experts(
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
 
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
@@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor,
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = get_default_config(
-    intermediate_cache1 = torch.empty(
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )
+
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config[
-    invoke_fused_moe_kernel(
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     if inplace:
-        return torch.sum(
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
 
 
 def fused_moe(
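For reference, the two kernel launches plus `ops.gelu_and_mul` compute the usual gated-MLP expert mixture. Below is a naive per-token sketch of the same math, assuming (as the asserts above imply) `w1` of shape `[E, N, K]` and `w2` of shape `[E, K_out, N // 2]`, and an exact-GELU gate; the fused op's activation approximation may differ.

```python
import torch
import torch.nn.functional as F

def reference_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids):
    """Naive per-(token, expert) loop mirroring the fused path above."""
    M, _ = hidden_states.shape
    E, N, _ = w1.shape
    out = torch.zeros(M, w2.shape[1], dtype=hidden_states.dtype)
    for i in range(M):
        for weight, e in zip(topk_weights[i].tolist(), topk_ids[i].tolist()):
            h = w1[e] @ hidden_states[i]              # first grouped GEMM -> [N]
            gate = F.gelu(h[: N // 2]) * h[N // 2 :]  # gelu_and_mul -> [N // 2]
            out[i] += weight * (w2[e] @ gate)         # second grouped GEMM, routed weight
    return out

# Tiny CPU smoke test with made-up sizes.
M, K, E, I = 3, 8, 4, 16
hidden = torch.randn(M, K)
w1 = torch.randn(E, 2 * I, K)   # N = 2 * I because gelu_and_mul halves the intermediate
w2 = torch.randn(E, K, I)
weights, ids = torch.topk(torch.softmax(torch.randn(M, E), dim=-1), 2)
print(reference_fused_experts(hidden, w1, w2, weights, ids).shape)  # torch.Size([3, 8])
```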
@@ -532,25 +542,28 @@ def fused_moe(
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
     if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
     else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(
+        topk_weights, topk_ids = fused_topk_v0_4_3(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        override_config=override_config,
+        use_fp8=use_fp8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
 
 
 def fused_topk_v0_4_3(
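Putting the routing and the expert computation together, a call into this module would look roughly like the sketch below. This is a hypothetical smoke test, not taken from the package: it assumes a CUDA device, vllm's custom ops, and the argument order shown in the signatures above; the sizes are made up.

```python
import torch
from sglang.srt.layers.fused_moe import fused_topk, fused_experts

device, dtype = "cuda", torch.float16
hidden_states = torch.randn(16, 512, device=device, dtype=dtype)   # 16 tokens, hidden 512
w1 = torch.randn(8, 2 * 1024, 512, device=device, dtype=dtype)     # 8 experts, gate+up proj
w2 = torch.randn(8, 512, 1024, device=device, dtype=dtype)         # down projection
gating_output = torch.randn(16, 8, device=device, dtype=dtype)

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)
out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, inplace=False)
print(out.shape)  # torch.Size([16, 512])
```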
@@ -560,6 +573,7 @@ def fused_topk_v0_4_3(
     renormalize: bool,
 ):
     import vllm._moe_C as moe_kernels
+
     M, _ = hidden_states.shape
 
     topk_weights = torch.empty(
@@ -579,4 +593,4 @@ def fused_topk_v0_4_3(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights, topk_ids
+    return topk_weights, topk_ids
sglang/srt/layers/logits_processor.py
CHANGED
@@ -1,7 +1,7 @@
 """Logits processing."""
 
 import dataclasses
-from typing import List
+from typing import List, Union
 
 import torch
 from torch import nn
@@ -31,6 +31,27 @@ class LogitProcessorOutput:
     decode_top_logprobs: List
 
 
+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
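The point of the new dataclass is that the logits processor no longer needs a full `InputMetadata`: anything that can populate these five fields can drive it, presumably including callers like the new `cuda_graph_runner` in the file list. A sketch of building one directly for a decode step; the field values and the `ForwardMode` import path are assumptions, not taken from the package:

```python
import torch
# Import paths assumed from the 0.1.20 layout; adjust if they differ.
from sglang.srt.layers.logits_processor import LogitsMetadata
from sglang.srt.managers.controller.infer_batch import ForwardMode

logits_metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    extend_seq_lens=torch.empty(0, dtype=torch.long),   # unused in decode mode
    extend_start_loc=torch.empty(0, dtype=torch.long),  # unused in decode mode
    return_logprob=True,
    top_logprobs_nums=[5, 5, 0, 0],  # top-5 logprobs for the first two requests only
)
```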
@@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs,
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
     ):
         logprobs_cumsum = torch.cumsum(
             prefill_token_logprobs, dim=0, dtype=torch.float32
         )
 
-        start =
-        end = start +
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
         start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module):
             + prefill_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
-            (
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
         )
 
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs,
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
-        if
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
-                k =
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
|
|
73
94
|
else:
|
74
95
|
prefill_top_logprobs, decode_top_logprobs = [], []
|
75
96
|
pt = 0
|
76
|
-
extend_seq_lens_cpu =
|
97
|
+
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
77
98
|
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
78
99
|
if extend_seq_len == 0:
|
79
100
|
prefill_top_logprobs.append([])
|
80
101
|
decode_top_logprobs.append([])
|
81
102
|
continue
|
82
|
-
k =
|
103
|
+
k = logits_metadata.top_logprobs_nums[i]
|
83
104
|
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
84
105
|
vs_cpu = t.values.tolist()
|
85
106
|
ps_cpu = t.indices.tolist()
|
@@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module):
|
|
91
112
|
|
92
113
|
return prefill_top_logprobs, decode_top_logprobs
|
93
114
|
|
94
|
-
def forward(
|
115
|
+
def forward(
|
116
|
+
self,
|
117
|
+
input_ids,
|
118
|
+
hidden_states,
|
119
|
+
weight,
|
120
|
+
logits_metadata: Union[LogitsMetadata, InputMetadata],
|
121
|
+
):
|
122
|
+
if isinstance(logits_metadata, InputMetadata):
|
123
|
+
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
|
124
|
+
assert isinstance(logits_metadata, LogitsMetadata)
|
125
|
+
|
95
126
|
# Get the last hidden states and last logits for the next token prediction
|
96
|
-
if
|
127
|
+
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
97
128
|
last_index = None
|
98
129
|
last_hidden = hidden_states
|
99
130
|
else:
|
100
131
|
last_index = (
|
101
|
-
torch.cumsum(
|
132
|
+
torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
102
133
|
- 1
|
103
134
|
)
|
104
135
|
last_hidden = hidden_states[last_index]
|
@@ -108,8 +139,13 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size]
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            last_logits /= self.config.final_logit_softcapping
+            last_logits = torch.tanh(last_logits)
+            last_logits *= self.config.final_logit_softcapping
+
         # Return only last_logits if logprob is not requested
-        if not
+        if not logits_metadata.return_logprob:
             return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
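The new block is final logit softcapping as used by Gemma 2 (note the new `gemma2.py` model in the file list): logits are squashed smoothly into `(-cap, cap)` rather than clipped. An equivalent out-of-place form:

```python
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Same math as the in-place divide / tanh / multiply sequence above.
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap(x, cap=30.0))  # large logits saturate near ±30, small ones barely change
```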
@@ -120,7 +156,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 all_logits = last_logits
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
@@ -133,15 +169,15 @@ class LogitsProcessor(nn.Module):
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
             # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs,
+                    all_logprobs, logits_metadata
                 )
             else:
                 prefill_top_logprobs = decode_top_logprobs = None
 
-            if
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=all_logprobs,
@@ -161,7 +197,7 @@ class LogitsProcessor(nn.Module):
             ]
 
             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                prefill_token_logprobs,
+                prefill_token_logprobs, logits_metadata
             )
 
             return LogitProcessorOutput(
|