sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/int8_kernel.py
CHANGED
@@ -1,7 +1,17 @@
+import functools
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import get_device_name
+
+logger = logging.getLogger(__name__)
+
 
 @triton.jit
 def _per_token_quant_int8(
@@ -52,3 +62,320 @@ def per_token_quant_int8(x):
     )
 
     return x_q, scales
+
+
+@triton.jit
+def _per_token_group_quant_int8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    # Stride of input
+    y_stride,
+    # Collums of input
+    N,
+    # Avoid to divide zero
+    eps,
+    # Information for int8
+    int8_min,
+    int8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group quantization on a
+    tensor.
+
+    This function converts the tensor values into int8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    y_s_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / int8_max
+    y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+def per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform per-token-group quantization on an input tensor `x`.
+
+    It converts the tensor values into signed int8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tenosr with ndim >= 2.
+        group_size: The group size used for quantization.
+        eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor. Note that only `torch.int8` is supported for now.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+    """
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    _per_token_group_quant_int8[(M,)](
+        x,
+        x_q,
+        x_s,
+        group_size,
+        N,
+        eps,
+        int8_min=int8_min,
+        int8_max=int8_max,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return x_q, x_s
+
+
+@triton.jit
+def _w8a8_block_int8_matmul(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+@functools.lru_cache
+def get_w8a8_block_int8_configs(
+    N: int, K: int, block_n: int, block_k: int
+) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the w8a8 block fp8 kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    device_name = get_device_name().replace(" ", "_")
+    json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json"
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                "Using configuration from %s for W8A8 Block INT8 kernel.",
+                config_file_path,
+            )
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    logger.warning(
+        (
+            "Using default W8A8 Block INT8 kernel config. Performance might be sub-optimal! "
+            "Config file not found at %s"
+        ),
+        config_file_path,
+    )
+    return None
+
+
+def w8a8_block_int8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    """This function performs matrix multiplication with block-wise quantization.
+
+    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+    The output is returned in the specified `output_dtype`.
+
+    Args:
+        A: The input tensor, e.g., activation.
+        B: The input tensor, e.g., weight.
+        As: The per-token-group quantization scale for `A`.
+        Bs: The per-block quantization scale for `B`.
+        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+        output_dytpe: The dtype of the returned tensor.
+
+    Returns:
+        torch.Tensor: The result of matmul.
+    """
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+    else:
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
+
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        )
+
+    _w8a8_block_int8_matmul[grid](
+        A,
+        B,
+        C,
+        As,
+        Bs,
+        M,
+        N,
+        K,
+        block_n,
+        block_k,
+        A.stride(-2),
+        A.stride(-1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(-2),
+        C.stride(-1),
+        As.stride(-2),
+        As.stride(-1),
+        Bs.stride(1),
+        Bs.stride(0),
+        **config,
+    )
+
+    return C
sglang/srt/layers/quantization/int8_utils.py
ADDED
@@ -0,0 +1,73 @@
+from typing import List, Optional, Tuple
+
+import torch
+
+from sglang.srt.layers.quantization.int8_kernel import (
+    per_token_group_quant_int8,
+    w8a8_block_int8_matmul,
+)
+
+
+def apply_w8a8_block_int8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
+    output = w8a8_block_int8_matmul(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+    )
+
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def input_to_int8(
+    x: torch.Tensor, dtype: torch.dtype = torch.int8
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function quantizes input values to int8 values with tensor-wise quantization."""
+    iinfo = torch.iinfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    int8_min, int8_max = iinfo.min, iinfo.max
+    scale = int8_max / amax
+    x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+def block_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+) -> torch.Tensor:
+    """This function conducts block-wise dequantization.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The outputs are dequantized tensor.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = x_q_block.to(torch.float32)
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ] *= x_s[j][i]
+
+    return x_dq_block
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
 
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-
+
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+        if isinstance(layer, AttentionBackend):
+            return ModelOptFp8KVCacheMethod(self)
+
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__(quant_config)
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
@@ -29,7 +29,7 @@ SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.
+        self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tensor_model_parallel_group().device_group
 
         if global_server_args_dict["enable_dp_attention"]:
@@ -41,14 +41,28 @@ class Sampler(nn.Module):
         sampling_info: SamplingBatchInfo,
         return_logprob: bool,
         top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+        batch_next_token_ids: Optional[torch.Tensor] = None,
     ):
+        """Run a sampler & compute logprobs and update logits_output accordingly.
+
+        Args:
+            logits_output: The logits from the model forward
+            sampling_info: Metadata for sampling
+            return_logprob: If set, store the output logprob information to
+                logits_output
+            top_logprobs_nums: Number of top lobprobs per sequence in a batch
+            batch_next_token_ids: next token IDs. If set, skip sampling and only
+                compute output logprobs It is used for speculative decoding which
+                performs sampling in draft workers.
+        """
         logits = logits_output.next_token_logits
 
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
             self._apply_custom_logit_processor(logits, sampling_info)
 
-        if self.
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
@@ -58,13 +72,15 @@ class Sampler(nn.Module):
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids
+            if batch_next_token_ids is None:
+                batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
             # Post process logits
             logits.div_(sampling_info.temperatures)
-
+            logits[:] = torch.softmax(logits, dim=-1)
+            probs = logits
             del logits
 
             if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -78,38 +94,43 @@ class Sampler(nn.Module):
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                     ).clamp(min=torch.finfo(probs.dtype).min)
 
-
-
-
-
-                if sampling_info.need_min_p_sampling:
-                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                    batch_next_token_ids = min_p_sampling_from_probs(
-                        probs, uniform_samples, sampling_info.min_ps
+                if batch_next_token_ids is None:
+                    max_top_k_round, batch_size = 32, probs.shape[0]
+                    uniform_samples = torch.rand(
+                        (max_top_k_round, batch_size), device=probs.device
                     )
-
-
+                    if sampling_info.need_min_p_sampling:
+                        probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                        probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                        batch_next_token_ids = min_p_sampling_from_probs(
+                            probs, uniform_samples, sampling_info.min_ps
+                        )
+                    else:
+                        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                            probs,
+                            uniform_samples,
+                            sampling_info.top_ks,
+                            sampling_info.top_ps,
+                            filter_apply_order="joint",
+                        )
+
+                        if self.use_nan_detection and not torch.all(success):
+                            logger.warning("Detected errors during sampling!")
+                            batch_next_token_ids = torch.zeros_like(
+                                batch_next_token_ids
+                            )
+
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                if batch_next_token_ids is None:
+                    # A slower fallback implementation with torch native operations.
+                    batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                         probs,
-                        uniform_samples,
                         sampling_info.top_ks,
                         sampling_info.top_ps,
-
+                        sampling_info.min_ps,
+                        sampling_info.need_min_p_sampling,
                     )
 
-                if self.use_nan_detectioin and not torch.all(success):
-                    logger.warning("Detected errors during sampling!")
-                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
-            elif global_server_args_dict["sampling_backend"] == "pytorch":
-                # A slower fallback implementation with torch native operations.
-                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                    probs,
-                    sampling_info.top_ks,
-                    sampling_info.top_ps,
-                    sampling_info.min_ps,
-                    sampling_info.need_min_p_sampling,
-                )
             if return_logprob:
                 # clamp to avoid -inf
                 logprobs = torch.log(
@@ -128,6 +149,12 @@ class Sampler(nn.Module):
                 logits_output.next_token_top_logprobs_idx,
             ) = get_top_logprobs(logprobs, top_logprobs_nums)
 
+            if any(x is not None for x in token_ids_logprobs):
+                (
+                    logits_output.next_token_token_ids_logprobs_val,
+                    logits_output.next_token_token_ids_logprobs_idx,
+                ) = get_token_ids_logprobs(logprobs, token_ids_logprobs)
+
             logits_output.next_token_logprobs = logprobs[
                 torch.arange(len(batch_next_token_ids), device=sampling_info.device),
                 batch_next_token_ids,
@@ -223,6 +250,10 @@ def top_p_normalize_probs_torch(
 
 
 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+    assert len(top_logprobs_nums) == logprobs.shape[0], (
+        len(top_logprobs_nums),
+        logprobs.shape[0],
+    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
@@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
         output_top_logprobs_val.append(values[i][:k])
         output_top_logprobs_idx.append(indices[i][:k])
     return output_top_logprobs_val, output_top_logprobs_idx
+
+
+def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
+    output_token_ids_logprobs_val = []
+    output_token_ids_logprobs_idx = []
+    for i, token_ids in enumerate(token_ids_logprobs):
+        if token_ids is not None:
+            output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist())
+            output_token_ids_logprobs_idx.append(token_ids)
+        else:
+            output_token_ids_logprobs_val.append([])
+            output_token_ids_logprobs_idx.append([])
+
+    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx