vllm-npu 0.4.2__py3-none-any.whl
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/layers/sampler.py (new file)
@@ -0,0 +1,1051 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                           SamplerOutput, SequenceGroupOutput, SequenceOutput)

# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row in
    logits for the next token to be sampled; however, for a seq_group with a
    prompt request with the prompt_logprobs sampling parameter, there are rows
    in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results,
                                     sampling_metadata,
                                     prompt_logprobs,
                                     sample_logprobs,
                                     on_device_tensors=on_device_tensors)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability distribution
        of greedily-sampled tokens such that multinomial sampling would sample
        the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        # Modify greedy probs if include_gpu_probs_tensor is set.
        return self.include_gpu_probs_tensor

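# Illustrative sketch (not part of the file above): the order of the
# logit-shaping steps described in the Sampler docstring (temperature, then
# top-k/top-p truncation, then sampling), shown with plain torch ops on a
# single toy logits row with made-up numbers. The real implementation batches
# these steps across sequence groups with per-sequence parameters.
import torch

torch.manual_seed(0)
toy_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
temperature, top_k = 0.8, 3

scaled = toy_logits / temperature                        # step 4
kth_best = scaled.topk(top_k, dim=-1).values[..., -1:]   # step 5 (top-k only)
truncated = scaled.masked_fill(scaled < kth_best, -float("inf"))
probs = torch.softmax(truncated, dim=-1)                 # step 6
next_token = torch.multinomial(probs, num_samples=1)
print(probs, next_token)
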
def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask


def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
        have not been generated yet
    """
    # list of indices in logits that will be set to -inf
    logits_to_penalize: List[Tuple[int, int]] = []
    logits_applied = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params

        sample_indices = seq_group.sample_indices
        logits_applied += len(sample_indices) + len(
            seq_group.prompt_logprob_indices)
        if not seq_group.do_sample:
            continue

        start_idx = sample_indices[0]
        min_tokens = sampling_params.min_tokens
        token_ids_to_penalize = sampling_params.all_stop_token_ids
        if min_tokens > 0 and token_ids_to_penalize:
            seqs_to_penalize = []
            for j, seq_id in enumerate(seq_ids):
                seq_data = seq_group.seq_data[seq_id]
                if len(seq_data.output_token_ids) < min_tokens:
                    seqs_to_penalize.append(j)

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert logits_applied == logits.shape[0]
    return logits


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits

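# Illustrative sketch (not part of the file above): the OpenAI-style penalty
# arithmetic on a toy batch of one sequence and a 4-token vocabulary. All the
# numbers are made up. Token 0 was generated twice and token 2 once, so the
# frequency and presence penalties push them down; a repetition penalty of
# 1.0 is a no-op here.
import torch

toy_logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]])
prompt_token_ids = torch.tensor([[1, 2]])
output_token_ids = torch.tensor([[0, 0, 2]])
presence = torch.tensor([0.2])
frequency = torch.tensor([0.5])
repetition = torch.tensor([1.0])

penalized = _apply_penalties(toy_logits, prompt_token_ids, output_token_ids,
                             presence, frequency, repetition)
print(penalized)  # tensor([[ 0.8000, -1.0000, -0.2000,  0.0000]])
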
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits

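# Illustrative sketch (not part of the file above): the sort-based top-k /
# top-p masking applied to a single toy row with made-up numbers. With k=2
# and p=0.9, only the two highest logits (3.0 and 2.0) survive; everything
# else is set to -inf, so a later softmax gives those tokens zero probability.
import torch

toy_logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
top_p = torch.tensor([0.9])
top_k = torch.tensor([2])
print(_apply_top_k_top_p(toy_logits.clone(), top_p, top_k))
# -> only indices 1 and 2 keep their logits; indices 0 and 3 become -inf.
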
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Run greedy sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. The length of
            samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    samples = samples.tolist()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Run random sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples. The
            length of samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        is_prompt = seq_group.is_prompt
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> SampleResultType:
    """Run beam sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
            on selected sample indices.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        is_prompt = seq_group.is_prompt
        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs: List[int] = [
                seq_group.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            cumulative_logprobs_tensor = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(
                generator=seq_group.generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)

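# Illustrative sketch (not part of the file above): why dividing probabilities
# by Exp(1) noise and taking argmax is a multinomial draw. For each row,
# argmax_i(p_i / q_i) with q_i ~ Exp(1) equals argmin_i(q_i / p_i), and
# q_i / p_i is Exp(p_i)-distributed, so index i wins with probability
# p_i / sum_j p_j. The empirical frequencies below should be close to p.
import torch

torch.manual_seed(0)
p = torch.tensor([0.1, 0.2, 0.7])
rows = p.expand(100_000, -1).contiguous()  # one independent draw per row
samples = _multinomial(rows, num_samples=1).flatten()
print(torch.bincount(samples, minlength=3) / samples.numel())
# Approximately tensor([0.1, 0.2, 0.7]).
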
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
            }

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results, sampled_token_ids_tensor


def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> SampleResultType:
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _sample(
    probs: torch.Tensor, logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
    """
    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.

    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in a batch.
            If sampling is skipped, it returns ([], [])
        sampled_token_ids_tensor: A tensor of sampled token ids.
    """
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)

def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): List of chosen token indices.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                    Each element in the returned tensor represents the rank
                    of the chosen token in the input logprob tensor.
    """
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    return (x > vals[:, None]).long().sum(1).add_(1)

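# Illustrative sketch (not part of the file above): _get_ranks on a toy
# 2 x 4 logprob tensor with made-up values. Rank 1 means the chosen token had
# the highest logprob in its row.
import torch

toy_logprobs = torch.tensor([[-0.1, -2.0, -3.0, -0.5],
                             [-1.0, -0.2, -0.3, -4.0]])
chosen = torch.tensor([3, 1])
print(_get_ranks(toy_logprobs, chosen))  # tensor([2, 1])
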
def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
            logprob per vocab. Sequence groups' query tokens are batched in a
            single flattened tensor. For example, assuming there are N
            seq groups, it is sorted by prefill tokens for seq_group_1 (if
            prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain different number of seq_ids from
            sampling_metadata.seq_groups. It is because beam search creates
            2 * BEAM_WIDTH number of samples (whereas there are only up to
            BEAM_WIDTH number of seq_ids).

    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: List[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: List[int] = []
    # The largest requested number of logprobs. We find logprobs as many as the
    # largest num logprobs in this API.
    largest_num_logprobs = 1

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

    assert len(next_token_ids) == len(query_indices)

    if len(query_indices) == 0:
        empty_sampled_logprob: SampleLogprobs = []
        empty_prompt_logprob: Optional[PromptLogprobs] = None
        return [empty_prompt_logprob], [empty_sampled_logprob]

    query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
    next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)

    # (num_selected_query_tokens, num_logprobs). Note that query_indices can
    # contain duplicates if beam search is enabled.
    selected_logprobs = logprobs[[
        query_indices_gpu,
        next_token_ids_gpu,
    ]]
    ranks = _get_ranks(
        logprobs[query_indices_gpu],
        next_token_ids_gpu,
    )
    assert selected_logprobs.shape[0] == ranks.shape[0]

    # Logprobs of topk tokens for a batch of sequence groups.
    # (num_query_tokens_across_batch).
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    selected_logprobs = selected_logprobs.cpu()
    ranks = ranks.cpu()

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
             top_logprobs, selected_logprobs_idx, top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the prompt logprob from a sequence group if needed."""
    sampling_params = seq_group.sampling_params
    is_prompt = seq_group.is_prompt

    # Find prompt logprobs
    prompt_logprobs: Optional[PromptLogprobs] = None
    if (is_prompt and sampling_params.prompt_logprobs is not None):
        prompt_logprobs = []
        num_logprobs = sampling_params.prompt_logprobs
        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
        for token_id in next_prompt_tokens:
            # Calculate the prompt logprob of the real prompt tokens.
            # Use tuple here for performance (to use to_list()).
            # {token_id: (logprob, rank_from_vocab)}
            prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
                token_id: (selected_logprobs[selected_logprobs_idx].item(),
                           ranks[selected_logprobs_idx].item())
            }

            # Add top K prompt logprobs along with its rank.
            if num_logprobs > 0:
                prompt_logprobs_dict.update(
                    zip(
                        top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
                        zip(
                            top_logprobs[
                                top_logprob_idx, :num_logprobs].tolist(),
                            # This is ranks. Since top_logprob is sorted,
                            # we can just use a range here.
                            range(1, num_logprobs + 1))))
            prompt_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in prompt_logprobs_dict.items()
            })
            # + 1 to go to the next prompt token.
            top_logprob_idx += 1
            selected_logprobs_idx += 1
    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the sample logprob if needed."""
    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    if num_logprobs is None:
        num_logprobs = 0
    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if seq_group.do_sample:
        assert len(next_token_ids) > 0
        for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
            # Calculate the sample logprob of the real sampled tokens.
            # Use tuple here for performance (to use to_list()).
            # token_id: (logprob, rank_from_vocab)
            sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
                next_token_id:
                (selected_logprobs[selected_logprobs_idx].item(),
                 ranks[selected_logprobs_idx].item())
            }
            # +1 to go to the next sampled token. Note that
            # selected_logprobs can contain duplicates unlike top_logprobs
            # when beam search is enabled.
            selected_logprobs_idx += 1

            # Second, add top K logprobs along with its rank.
            if num_logprobs >= 0:
                sampled_logprobs_dict.update(
                    zip(
                        top_token_ids[top_logprob_idx +
                                      parent_id, :num_logprobs].tolist(),
                        zip(
                            top_logprobs[top_logprob_idx +
                                         parent_id, :num_logprobs].tolist(),
                            # This is rank. Since top_logprob is sorted, we
                            # can just use a range here.
                            range(1, num_logprobs + 1))))
            sampled_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in
                sampled_logprobs_dict.items()
            })
        # There are len(seq_ids) number of sampled tokens for the current
        # sequence group in top_logprobs. Jump to the next seq_group.
        top_logprob_idx += len(seq_ids)
    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens such
    that each sampled token has a "probability" of 1.0. This is required by
    speculative decoding, which depends on the sampling method being encoded
    within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Multiply by temperature, top-k and top-p masking, penalize tokens
                according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
                distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
                highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed probability
    distribution has the following property: we can sample from it independently
    and find that the token sampled by the Sampler has a frequency corresponding
    to how often we see it in our sampling. In other words, for tokens sampled
    with vLLM's random SamplingType, the computed probability distribution
    encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths. This
    has implications on the overall design of the sampler, e.g. how to record
    accurate logprobs for the user, so this improvement is deferred to later.
    """
    # NOTE: logprobs are not modified so they can be returned to the user.
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0

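# Illustrative sketch (not part of the file above): the in-place one-hot
# rewrite performed by _modify_greedy_probs_inplace on two greedy rows of a
# toy 3 x 4 probability tensor (row 1 is left untouched).
import torch

toy_probs = torch.full((3, 4), 0.25)
toy_logprobs = toy_probs.log()
greedy_rows = torch.tensor([0, 2])
greedy_tokens = torch.tensor([1, 3])
_modify_greedy_probs_inplace(toy_logprobs, toy_probs, greedy_rows,
                             greedy_tokens)
print(toy_probs[0])  # tensor([0., 1., 0., 0.])
print(toy_probs[1])  # unchanged: tensor([0.2500, 0.2500, 0.2500, 0.2500])
print(toy_probs[2])  # tensor([0., 0., 0., 1.])
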
def _build_sampler_output(
    sample_results: SampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g. in
            speculative decoding rejection sampling.
    """

    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids = seq_group.seq_ids
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
    )


def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
    """Get a list of next prompt tokens to compute logprob from a
    given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Query token needs to know the next prompt token id to compute
    prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in prefill
    stage.

    Returns:
        A list of next prompt tokens to compute logprob.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
    query_len = seq_group.query_len
    assert query_len is not None
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
    next_token_index_end = min(computed_len + query_len + 1,
                               len(prompt_tokens))
    next_prompt_tokens = prompt_tokens[
        next_token_index_start:next_token_index_end]
    return next_prompt_tokens
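# Illustrative sketch (not part of the file above): the index arithmetic used
# by _get_next_prompt_tokens, with made-up numbers. With 4 prompt tokens
# already computed and a query of length 3, the "next" prompt tokens for the
# three query positions are prompt_tokens[5:8] (clamped to the prompt length).
prompt_tokens = list(range(100, 110))  # 10 made-up prompt token ids
computed_len = 4
query_len = 3
next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens))
print(prompt_tokens[next_token_index_start:next_token_index_end])
# [105, 106, 107]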