vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1051 @@
1
+ """A layer that samples the next tokens from the model's outputs."""
2
+ import itertools
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from vllm.model_executor.layers.ops.sample import sample as sample_triton
9
+ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
10
+ SamplingTensors,
11
+ SequenceGroupToSample)
12
+ from vllm.sampling_params import SamplingType
13
+ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
14
+ SamplerOutput, SequenceGroupOutput, SequenceOutput)
15
+
16
+ # (num_token_ids, num_parent_ids) per sequence group.
17
+ SampleResultType = List[Tuple[List[int], List[int]]]
18
+
19
+
20
+ class Sampler(nn.Module):
21
+ """Samples the next tokens from the model's outputs.
22
+
23
+ This layer does the following:
24
+ 1. Discard the hidden states that are not used for sampling (i.e., all
25
+ tokens except the final one in each prompt).
26
+ 2. Compute the logits for the next tokens.
27
+ 3. Apply presence, frequency and repetition penalties.
28
+ 4. Apply temperature scaling.
29
+ 5. Apply top-p and top-k truncation.
30
+ 6. Sample the next tokens.
31
+ Here, each sequence group within the batch can have different sampling
32
+ parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
33
+
34
+ The structure of the logits tensor is coupled with the seq_groups in
35
+ sampling_metadata. Typically, each sequence in each seq_group has one row in
36
+ logits for the next token to be sampled; however, for a seq_group with a
37
+ prompt request with the prompt_logprobs sampling parameter, there are rows
38
+ in logits for each token in the input prompt.
39
+ """
40
+
41
+ def __init__(self):
42
+ super().__init__()
43
+
44
+ # Whether or not the SamplerOutput should have on-device tensors
45
+ # containing the sampled token ids and probabilities. This is used by
46
+ # speculative decoding.
47
+ self.include_gpu_probs_tensor = False
48
+
49
+ def forward(
50
+ self,
51
+ logits: torch.Tensor,
52
+ sampling_metadata: SamplingMetadata,
53
+ ) -> Optional[SamplerOutput]:
54
+ """
55
+ Args:
56
+ logits: (num_tokens, vocab_size).
57
+ sampling_metadata: Metadata for sampling.
58
+ """
59
+ assert logits is not None
60
+ _, vocab_size = logits.shape
61
+
62
+ logits = _apply_min_tokens_penalty(logits, sampling_metadata)
63
+
64
+ # Prepare sampling tensors with pinned memory to avoid blocking.
65
+ (sampling_tensors, do_penalties, do_top_p_top_k,
66
+ do_min_p) = SamplingTensors.from_sampling_metadata(
67
+ sampling_metadata, vocab_size, logits.device, logits.dtype)
68
+
69
+ # Apply presence and frequency penalties.
70
+ if do_penalties:
71
+ logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
72
+ sampling_tensors.output_tokens,
73
+ sampling_tensors.presence_penalties,
74
+ sampling_tensors.frequency_penalties,
75
+ sampling_tensors.repetition_penalties)
76
+
77
+ # Apply temperature scaling.
78
+ # Use in-place division to avoid creating a new tensor.
79
+ logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
80
+
81
+ if do_top_p_top_k:
82
+ logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
83
+ sampling_tensors.top_ks)
84
+
85
+ if do_min_p:
86
+ logits = _apply_min_p(logits, sampling_tensors.min_ps)
87
+
88
+ # We use float32 for probabilities and log probabilities.
89
+ # Compute the probabilities.
90
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float)
91
+ # Compute the log probabilities.
92
+ logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
93
+
94
+ # Sample the next tokens.
95
+ sample_results, maybe_sampled_tokens_tensor = _sample(
96
+ probs,
97
+ logprobs,
98
+ sampling_metadata,
99
+ sampling_tensors,
100
+ include_gpu_probs_tensor=self.include_gpu_probs_tensor,
101
+ modify_greedy_probs=self._should_modify_greedy_probs_inplace,
102
+ )
103
+
104
+ if self.include_gpu_probs_tensor:
105
+ assert maybe_sampled_tokens_tensor is not None
106
+ on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
107
+ else:
108
+ on_device_tensors = None
109
+
110
+ # Get the logprobs query results.
111
+ prompt_logprobs, sample_logprobs = _get_logprobs(
112
+ logprobs, sampling_metadata, sample_results)
113
+ return _build_sampler_output(sample_results,
114
+ sampling_metadata,
115
+ prompt_logprobs,
116
+ sample_logprobs,
117
+ on_device_tensors=on_device_tensors)
118
+
119
+ @property
120
+ def _should_modify_greedy_probs_inplace(self) -> bool:
121
+ """Whether or not the sampler should modify the probability distribution
122
+ of greedily-sampled tokens such that multinomial sampling would sample
123
+ the greedily-sampled token.
124
+
125
+ In other words, if True then we set the probability of the greedily-
126
+ sampled token to 1.
127
+
128
+ This is used by speculative decoding, which requires that the sampling
129
+ method be encoded into the probability distribution.
130
+ """
131
+ # Modify greedy probs if include_gpu_probs_tensor is set.
132
+ return self.include_gpu_probs_tensor
133
+
134
+
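The forward() pass above applies the documented steps in a fixed order. As a minimal standalone sketch (not part of this file; no penalties, top-p, or SamplingMetadata, and the names are illustrative), the same order on a plain 1-D logits tensor looks like this:

    import torch

    def toy_sample(logits: torch.Tensor, temperature: float = 0.8, top_k: int = 50) -> int:
        # Temperature scaling (step 4 in the docstring above).
        logits = logits / temperature
        # Top-k truncation (step 5; top-p omitted for brevity).
        kth = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, -float("inf"))
        # Sample from the renormalized distribution (step 6).
        probs = torch.softmax(logits, dim=-1)
        return int(torch.multinomial(probs, num_samples=1))

    next_token = toy_sample(torch.randn(32_000))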
135
+ def _get_bin_counts_and_mask(
136
+ tokens: torch.Tensor,
137
+ vocab_size: int,
138
+ num_seqs: int,
139
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
140
+ # Compute the bin counts for the tokens.
141
+ # vocab_size + 1 for padding.
142
+ bin_counts = torch.zeros((num_seqs, vocab_size + 1),
143
+ dtype=torch.long,
144
+ device=tokens.device)
145
+ bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
146
+ bin_counts = bin_counts[:, :vocab_size]
147
+ mask = bin_counts > 0
148
+
149
+ return bin_counts, mask
150
+
151
+
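A quick toy check of the helper above, assuming the wheel is installed so the module-private function can be imported; the token values are arbitrary:

    import torch
    from vllm.model_executor.layers.sampler import _get_bin_counts_and_mask

    tokens = torch.tensor([[2, 2, 1], [0, 3, 3]])   # two sequences, vocab of 4
    bin_counts, mask = _get_bin_counts_and_mask(tokens, vocab_size=4, num_seqs=2)
    print(bin_counts)  # tensor([[0, 1, 2, 0], [1, 0, 0, 2]])
    print(mask)        # True where a vocab entry appeared at least once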
152
+ def _apply_min_tokens_penalty(
153
+ logits: torch.Tensor,
154
+ sampling_metadata: SamplingMetadata,
155
+ ) -> torch.Tensor:
156
+ """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
157
+ have not been generated yet
158
+ """
159
+ # list of indices in logits that will be set to -inf
160
+ logits_to_penalize: List[Tuple[int, int]] = []
161
+ logits_applied = 0
162
+ for seq_group in sampling_metadata.seq_groups:
163
+ seq_ids = seq_group.seq_ids
164
+ sampling_params = seq_group.sampling_params
165
+
166
+ sample_indices = seq_group.sample_indices
167
+ logits_applied += len(sample_indices) + len(
168
+ seq_group.prompt_logprob_indices)
169
+ if not seq_group.do_sample:
170
+ continue
171
+
172
+ start_idx = sample_indices[0]
173
+ min_tokens = sampling_params.min_tokens
174
+ token_ids_to_penalize = sampling_params.all_stop_token_ids
175
+ if min_tokens > 0 and token_ids_to_penalize:
176
+ seqs_to_penalize = []
177
+ for j, seq_id in enumerate(seq_ids):
178
+ seq_data = seq_group.seq_data[seq_id]
179
+ if len(seq_data.output_token_ids) < min_tokens:
180
+ seqs_to_penalize.append(j)
181
+
182
+ if seqs_to_penalize:
183
+ # convert to the index into logits
184
+ seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
185
+ # itertools.product pairs each seq index with every token id
186
+ logits_to_penalize.extend(
187
+ itertools.product(seqs_to_penalize, token_ids_to_penalize))
188
+
189
+ if logits_to_penalize:
190
+ # use zip and * to group indices along each dimension
191
+ # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
192
+ logits[tuple(zip(*logits_to_penalize))] = -float("inf")
193
+
194
+ # verifies that no rows in logits were missed unexpectedly
195
+ assert logits_applied == logits.shape[0]
196
+ return logits
197
+
198
+
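The zip(*...) indexing trick used above can be checked in isolation; this is a small sketch with made-up (row, stop-token) pairs:

    import torch

    logits = torch.zeros(6, 8)
    pairs = [(1, 2), (1, 3), (5, 6)]              # (row in logits, stop token id)
    logits[tuple(zip(*pairs))] = -float("inf")    # same advanced-indexing trick
    print(logits[1, 2].item(), logits[5, 6].item())  # -inf -inf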
199
+ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
200
+ output_tokens_tensor: torch.Tensor,
201
+ presence_penalties: torch.Tensor,
202
+ frequency_penalties: torch.Tensor,
203
+ repetition_penalties: torch.Tensor) -> torch.Tensor:
204
+ num_seqs, vocab_size = logits.shape
205
+ _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
206
+ num_seqs)
207
+ output_bin_counts, output_mask = _get_bin_counts_and_mask(
208
+ output_tokens_tensor, vocab_size, num_seqs)
209
+
210
+ repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
211
+ repetition_penalties[~(prompt_mask | output_mask)] = 1.0
212
+ logits = torch.where(logits > 0, logits / repetition_penalties,
213
+ logits * repetition_penalties)
214
+
215
+ # We follow the definition in OpenAI API.
216
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details
217
+ logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
218
+ logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
219
+ return logits
220
+
221
+
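A worked toy example of the penalty math above (values arbitrary; a single sequence over a vocab of 4 whose prompt contained token 0 and whose outputs so far are [2, 2]), following the OpenAI parameter definitions referenced in the code:

    import torch

    logits = torch.tensor([[1.0, -1.0, 2.0, 0.5]])
    prompt_mask = torch.tensor([[True, False, False, False]])
    output_bin_counts = torch.tensor([[0., 0., 2., 0.]])
    output_mask = output_bin_counts > 0

    # Repetition penalty (1.2): divide positive logits, multiply negative ones,
    # but only for tokens seen in the prompt or in the output.
    rep = torch.full_like(logits, 1.0)
    rep[prompt_mask | output_mask] = 1.2
    logits = torch.where(logits > 0, logits / rep, logits * rep)

    # Frequency penalty scales with the count; presence penalty is flat.
    logits -= 0.5 * output_bin_counts
    logits -= 0.3 * output_mask.float()
    print(logits)  # tensor([[0.8333, -1.0000, 0.3667, 0.5000]])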
222
+ def _apply_top_k_top_p(
223
+ logits: torch.Tensor,
224
+ p: torch.Tensor,
225
+ k: torch.Tensor,
226
+ ) -> torch.Tensor:
227
+ logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
228
+
229
+ # Apply top-k.
230
+ top_k_mask = logits_sort.size(1) - k.to(torch.long)
231
+ # Get all the top_k values.
232
+ top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
233
+ top_k_mask = logits_sort < top_k_mask
234
+ logits_sort.masked_fill_(top_k_mask, -float("inf"))
235
+
236
+ # Apply top-p.
237
+ probs_sort = logits_sort.softmax(dim=-1)
238
+ probs_sum = probs_sort.cumsum(dim=-1)
239
+ top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
240
+ # at least one
241
+ top_p_mask[:, -1] = False
242
+ logits_sort.masked_fill_(top_p_mask, -float("inf"))
243
+
244
+ # Re-sort the probabilities.
245
+ src = torch.arange(logits_idx.shape[-1],
246
+ device=logits_idx.device).expand_as(logits_idx)
247
+ logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
248
+ index=logits_idx,
249
+ src=src)
250
+ logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
251
+ return logits
252
+
253
+
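A toy call of the function above (assuming the wheel is installed; the tensors are arbitrary) showing that everything outside the top-k / nucleus set ends up at -inf while the surviving logits are returned in their original positions:

    import torch
    from vllm.model_executor.layers.sampler import _apply_top_k_top_p

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    p = torch.tensor([0.9])   # nucleus mass to keep
    k = torch.tensor([2])     # keep at most the 2 highest-logit tokens
    print(_apply_top_k_top_p(logits, p, k))
    # tensor([[2., 1., -inf, -inf]])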
254
+ def _apply_min_p(
255
+ logits: torch.Tensor,
256
+ min_p: torch.Tensor,
257
+ ) -> torch.Tensor:
258
+ """
259
+ Adapted from
260
+ https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
261
+ """
262
+ probs = torch.softmax(logits, dim=-1)
263
+ top_probs, _ = probs.max(dim=-1, keepdim=True)
264
+ scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
265
+ tokens_to_remove = probs < scaled_min_p
266
+ logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
267
+
268
+ return logits
269
+
270
+
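Likewise for min-p: with min_p = 0.1, any token whose probability falls below 10% of the most likely token's probability is masked. A toy call under the same import assumption:

    import torch
    from vllm.model_executor.layers.sampler import _apply_min_p

    logits = torch.tensor([[3.0, 1.0, -2.0]])   # probs ~ [0.876, 0.118, 0.006]
    print(_apply_min_p(logits, torch.tensor([0.1])))
    # tensor([[3., 1., -inf]])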
271
+ def _greedy_sample(
272
+ selected_seq_groups: List[SequenceGroupToSample],
273
+ samples: torch.Tensor,
274
+ ) -> SampleResultType:
275
+ """Run greedy sampling on a given samples.
276
+
277
+ Args:
278
+ selected_seq_groups: A list of sequence groups batched.
279
+ samples: (num_selected_samples,) A tensor of samples. The length of
280
+ samples could be smaller than selected_seq_groups if
281
+ seq_group.do_sample is False.
282
+ Returns:
283
+ Tuple of (next_token_ids, parent_ids). The length of returned list is
284
+ same as the length of selected_seq_groups. If the corresponding
285
+ seq_group has do_sample=False, tuple contains ([], [])
286
+ """
287
+ samples = samples.tolist()
288
+ sample_idx = 0
289
+ results: SampleResultType = []
290
+ for seq_group in selected_seq_groups:
291
+ if not seq_group.do_sample:
292
+ results.append(([], []))
293
+ continue
294
+
295
+ seq_ids = seq_group.seq_ids
296
+ num_parent_seqs = len(seq_ids)
297
+ assert num_parent_seqs == 1, (
298
+ "Greedy sampling should have only one seq.")
299
+ parent_ids = list(range(num_parent_seqs))
300
+ next_token_ids = [samples[sample_idx]]
301
+ results.append((next_token_ids, parent_ids))
302
+ sample_idx += num_parent_seqs
303
+ return results
304
+
305
+
306
+ def _random_sample(
307
+ selected_seq_groups: List[SequenceGroupToSample],
308
+ random_samples: torch.Tensor,
309
+ ) -> SampleResultType:
310
+ """Run random sampling on a given samples.
311
+
312
+ Args:
313
+ selected_seq_groups: A list of sequence groups batched.
314
+ random_samples: (num_selected_samples,) A tensor of samples. The
315
+ length of samples could be smaller than selected_seq_groups if
316
+ seq_group.do_sample is False.
317
+ Returns:
318
+ Tuple of (next_token_ids, parent_ids). The length of returned list is
319
+ same as the length of selected_seq_groups. If the corresponding
320
+ seq_group has do_sample=False, tuple contains ([], [])
321
+ """
322
+ # Move the sampled token ids to the CPU so they can be iterated over.
323
+ random_samples = random_samples.cpu()
324
+ sample_idx = 0
325
+ results: SampleResultType = []
326
+ for seq_group in selected_seq_groups:
327
+ if not seq_group.do_sample:
328
+ results.append(([], []))
329
+ continue
330
+
331
+ seq_ids = seq_group.seq_ids
332
+ sampling_params = seq_group.sampling_params
333
+ is_prompt = seq_group.is_prompt
334
+ num_parent_seqs = len(seq_ids)
335
+ if is_prompt:
336
+ # Prompt phase.
337
+ parent_ids = [0] * sampling_params.best_of
338
+ next_token_ids = random_samples[
339
+ sample_idx, :sampling_params.best_of].tolist()
340
+ else:
341
+ # Generation phase.
342
+ parent_ids = list(range(num_parent_seqs))
343
+ next_token_ids = random_samples[sample_idx:sample_idx +
344
+ num_parent_seqs, 0].tolist()
345
+ results.append((next_token_ids, parent_ids))
346
+ sample_idx += num_parent_seqs
347
+ return results
348
+
349
+
350
+ def _beam_search_sample(
351
+ selected_seq_groups: List[SequenceGroupToSample],
352
+ logprobs: torch.Tensor,
353
+ ) -> SampleResultType:
354
+ """Run beam sampling on a given samples.
355
+
356
+ Args:
357
+ selected_seq_groups: A list of sequence groups batched.
358
+ logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
359
+ on selected sample indices.
360
+ Returns:
361
+ Tuple of (next_token_ids, parent_ids). The length of returned list is
362
+ same as the length of selected_seq_groups. If the corresponding
363
+ seq_group has do_sample=False, tuple contains ([], [])
364
+ """
365
+ # We sample 2 * beam_width candidates to make sure that with high
366
+ # probability we can get `beam_width` candidates in addition to
367
+ # the finished sequences for the next iteration. See
368
+ # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
369
+ # for details. See also HF reference:
370
+ # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
371
+ #
372
+ # NOTE: Beam search is not vectorized, so its speed can be slower than
373
+ # other sampling methods.
374
+ sample_idx = 0
375
+ results: SampleResultType = []
376
+ for seq_group in selected_seq_groups:
377
+ if not seq_group.do_sample:
378
+ results.append(([], []))
379
+ continue
380
+
381
+ is_prompt = seq_group.is_prompt
382
+ seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
383
+ num_parent_seqs = len(seq_ids)
384
+ beam_width = sampling_params.best_of
385
+ seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
386
+ if is_prompt:
387
+ # Prompt phase.
388
+ assert num_parent_seqs == 1, (
389
+ "Prompt input should have only one seq.")
390
+ parent_ids = [0] * (2 * beam_width)
391
+ _, next_token_ids = torch.topk(seq_group_logprobs[0],
392
+ 2 * beam_width)
393
+ next_token_ids = next_token_ids.tolist()
394
+ else:
395
+ # Generation phase.
396
+ cumulative_logprobs: List[int] = [
397
+ seq_group.seq_data[seq_id].cumulative_logprob
398
+ for seq_id in seq_ids
399
+ ]
400
+ cumulative_logprobs_tensor = torch.tensor(
401
+ cumulative_logprobs,
402
+ dtype=torch.float,
403
+ device=seq_group_logprobs.device)
404
+ seq_group_logprobs = (seq_group_logprobs +
405
+ cumulative_logprobs_tensor.unsqueeze(dim=1))
406
+ _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
407
+ 2 * beam_width)
408
+ topk_ids = topk_ids.tolist()
409
+ vocab_size = seq_group_logprobs.size(-1)
410
+ parent_ids = [i // vocab_size for i in topk_ids]
411
+ next_token_ids = [i % vocab_size for i in topk_ids]
412
+ results.append((next_token_ids, parent_ids))
413
+ sample_idx += num_parent_seqs
414
+ assert sample_idx == logprobs.size(0)
415
+ return results
416
+
417
+
418
+ # torch.multinomial forces a GPU<->CPU sync.
419
+ # Therefore, we use an optimized implementation instead.
420
+ # Note that we always sample with replacement.
421
+ # probs will be modified in place, but this is fine, as we pass
422
+ # in a copy already.
423
+ def _multinomial(
424
+ probs: torch.Tensor,
425
+ num_samples: int,
426
+ seq_groups: Optional[List[SequenceGroupToSample]] = None,
427
+ ) -> torch.Tensor:
428
+ if num_samples > 1:
429
+ # This is equivalent to torch.repeat_interleave (which also
430
+ # forces a GPU<->CPU sync).
431
+ # This allows us to do sampling with replacement by creating
432
+ # num_samples copies of each row in the tensor, and then
433
+ # batch sampling the resulting tensor.
434
+ probs = probs[:, None, :].expand(probs.shape[0], num_samples,
435
+ probs.shape[1]).contiguous().view(
436
+ -1, probs.shape[1])
437
+ q = torch.empty_like(probs)
438
+ if seq_groups is None:
439
+ q.exponential_()
440
+ else:
441
+ sample_idx = 0
442
+ for seq_group in seq_groups:
443
+ seq_ids = seq_group.seq_ids
444
+ next_sample_idx = sample_idx + len(seq_ids) * num_samples
445
+ q[sample_idx:next_sample_idx].exponential_(
446
+ generator=seq_group.generator)
447
+ sample_idx = next_sample_idx
448
+ return probs.div_(q).argmax(dim=1).view(-1, num_samples)
449
+
450
+
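The exponential-noise trick above is the same idea as the Gumbel-max trick: argmax_i of p_i / E_i with E_i ~ Exp(1) is a categorical sample with probabilities p_i. A quick empirical check (a standalone sketch, not part of this file):

    import torch

    torch.manual_seed(0)
    p = torch.tensor([0.1, 0.2, 0.7])
    noise = torch.empty(100_000, 3).exponential_()   # E_i ~ Exp(1)
    samples = (p / noise).argmax(dim=1)
    print(torch.bincount(samples, minlength=3) / samples.numel())
    # approximately tensor([0.10, 0.20, 0.70])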
451
+ def _sample_with_torch(
452
+ probs: torch.Tensor,
453
+ logprobs: torch.Tensor,
454
+ sampling_metadata: SamplingMetadata,
455
+ include_gpu_probs_tensor: bool,
456
+ modify_greedy_probs: bool,
457
+ ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
458
+ categorized_seq_group_ids: Dict[SamplingType,
459
+ List[int]] = {t: []
460
+ for t in SamplingType}
461
+ categorized_sample_indices = sampling_metadata.categorized_sample_indices
462
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
463
+ sampling_params = seq_group.sampling_params
464
+ sampling_type = sampling_params.sampling_type
465
+ categorized_seq_group_ids[sampling_type].append(i)
466
+
467
+ sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
468
+ sample_metadata = {}
469
+ multinomial_samples = {}
470
+
471
+ # Create output tensor for sampled token ids.
472
+ if include_gpu_probs_tensor:
473
+ sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
474
+ 1,
475
+ dtype=torch.long,
476
+ device=logprobs.device)
477
+ else:
478
+ sampled_token_ids_tensor = None
479
+
480
+ # Counterintuitively, having two loops here is actually faster.
481
+ # The first loop can run without waiting on GPU<->CPU sync.
482
+ for sampling_type in SamplingType:
483
+ sample_indices = categorized_sample_indices[sampling_type][:, 0]
484
+ num_tokens = len(sample_indices)
485
+ if num_tokens == 0:
486
+ continue
487
+
488
+ seq_group_id = categorized_seq_group_ids[sampling_type]
489
+ seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
490
+ sample_metadata[sampling_type] = (seq_group_id, seq_groups)
491
+ long_sample_indices = sample_indices.long()
492
+ if sampling_type == SamplingType.GREEDY:
493
+ greedy_samples = torch.argmax(logprobs[long_sample_indices],
494
+ dim=-1)
495
+
496
+ if include_gpu_probs_tensor:
497
+ # Store sampled tokens in output tensor.
498
+ sampled_token_ids_tensor[
499
+ long_sample_indices] = greedy_samples.unsqueeze(-1)
500
+
501
+ if modify_greedy_probs:
502
+ # If required, modify the probabilities such that sampling from
503
+ # the modified distribution would always sample the argmax
504
+ # token id.
505
+ _modify_greedy_probs_inplace(logprobs, probs,
506
+ long_sample_indices,
507
+ greedy_samples)
508
+
509
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
510
+ max_best_of_in_batch = 1
511
+ for seq_group in seq_groups:
512
+ if seq_group.is_prompt:
513
+ sampling_params = seq_group.sampling_params
514
+ max_best_of_in_batch = max(max_best_of_in_batch,
515
+ sampling_params.best_of)
516
+ seeded_args = {} if sampling_type == SamplingType.RANDOM else {
517
+ "seq_groups": seq_groups,
518
+ }
519
+
520
+ multinomial_samples[sampling_type] = _multinomial(
521
+ probs[long_sample_indices], max_best_of_in_batch,
522
+ **seeded_args)
523
+
524
+ if include_gpu_probs_tensor:
525
+ # Store sampled tokens in output tensor.
526
+ sampled_token_ids_tensor[
527
+ long_sample_indices] = multinomial_samples[sampling_type]
528
+
529
+ elif sampling_type == SamplingType.BEAM:
530
+ beam_search_logprobs = logprobs[sample_indices]
531
+ else:
532
+ raise ValueError(f"Unsupported sampling type: {sampling_type}")
533
+
534
+ # GPU<->CPU sync happens in the loop below.
535
+ # This also converts the sample output to Python objects.
536
+ for sampling_type in SamplingType:
537
+ if sampling_type not in sample_metadata:
538
+ continue
539
+ (seq_group_id, seq_groups) = sample_metadata[sampling_type]
540
+ if sampling_type == SamplingType.GREEDY:
541
+ sample_results = _greedy_sample(seq_groups, greedy_samples)
542
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
543
+ sample_results = _random_sample(seq_groups,
544
+ multinomial_samples[sampling_type])
545
+ elif sampling_type == SamplingType.BEAM:
546
+ sample_results = _beam_search_sample(seq_groups,
547
+ beam_search_logprobs)
548
+ sample_results_dict.update(zip(seq_group_id, sample_results))
549
+
550
+ sample_results = [
551
+ sample_results_dict.get(i, ([], []))
552
+ for i in range(len(sampling_metadata.seq_groups))
553
+ ]
554
+ return sample_results, sampled_token_ids_tensor
555
+
556
+
557
+ def _sample_with_triton_kernel(
558
+ probs: torch.Tensor,
559
+ logprobs: torch.Tensor,
560
+ sampling_metadata: SamplingMetadata,
561
+ sampling_tensors: SamplingTensors,
562
+ ) -> SampleResultType:
563
+ categorized_seq_group_ids: Dict[SamplingType,
564
+ List[int]] = {t: []
565
+ for t in SamplingType}
566
+ categorized_sample_indices = sampling_metadata.categorized_sample_indices
567
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
568
+ sampling_params = seq_group.sampling_params
569
+ sampling_type = sampling_params.sampling_type
570
+ categorized_seq_group_ids[sampling_type].append(i)
571
+
572
+ sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
573
+ sample_metadata = {}
574
+ max_best_of_in_batch = 1
575
+
576
+ # Counterintuitively, having two loops here is actually faster.
577
+ # The first loop can run without waiting on GPU<->CPU sync.
578
+ for sampling_type in SamplingType:
579
+ sample_indices = categorized_sample_indices[sampling_type][:, 0]
580
+ sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
581
+ num_tokens = len(sample_indices)
582
+ if num_tokens == 0:
583
+ continue
584
+ seq_group_id = categorized_seq_group_ids[sampling_type]
585
+ seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
586
+ sample_metadata[sampling_type] = (seq_group_id, seq_groups,
587
+ sample_indices,
588
+ sampled_token_indices)
589
+ if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
590
+ SamplingType.RANDOM_SEED):
591
+ for seq_group in seq_groups:
592
+ if seq_group.is_prompt:
593
+ sampling_params = seq_group.sampling_params
594
+ max_best_of_in_batch = max(max_best_of_in_batch,
595
+ sampling_params.best_of)
596
+ elif sampling_type == SamplingType.BEAM:
597
+ beam_search_logprobs = logprobs[sample_indices]
598
+ else:
599
+ raise ValueError(f"Unsupported sampling type: {sampling_type}")
600
+
601
+ sampled_tokens, _, _ = sample_triton(
602
+ probs=probs,
603
+ seeds=sampling_tensors.sampling_seeds,
604
+ max_best_of=max_best_of_in_batch,
605
+ sample_indices=sampling_tensors.sample_indices,
606
+ logprobs=logprobs,
607
+ # don't save logprobs because we have logic for that below
608
+ # TODO: use this instead of the CPU-based logic below
609
+ save_logprobs=False,
610
+ )
611
+
612
+ # GPU<->CPU sync happens in the loop below.
613
+
614
+ for sampling_type in SamplingType:
615
+ if sampling_type not in sample_metadata:
616
+ continue
617
+ (seq_group_id, seq_groups, sample_indices,
618
+ sampled_token_indices) = sample_metadata[sampling_type]
619
+ if sampling_type == SamplingType.GREEDY:
620
+ sample_results = _greedy_sample(
621
+ seq_groups, sampled_tokens[sampled_token_indices][:, 0])
622
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
623
+ sample_results = _random_sample(
624
+ seq_groups, sampled_tokens[sampled_token_indices])
625
+ elif sampling_type == SamplingType.BEAM:
626
+ sample_results = _beam_search_sample(seq_groups,
627
+ beam_search_logprobs)
628
+ sample_results_dict.update(zip(seq_group_id, sample_results))
629
+
630
+ sample_results = [
631
+ sample_results_dict.get(i, ([], []))
632
+ for i in range(len(sampling_metadata.seq_groups))
633
+ ]
634
+ return sample_results
635
+
636
+
637
+ def _sample(
638
+ probs: torch.Tensor, logprobs: torch.Tensor,
639
+ sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
640
+ include_gpu_probs_tensor: bool, modify_greedy_probs: bool
641
+ ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
642
+ """
643
+ Args:
644
+ probs: (num_query_tokens_in_batch, num_vocab)
645
+ logprobs: (num_query_tokens_in_batch, num_vocab)
646
+ sampling_metadata: The metadata for a batch for sampling.
647
+ sampling_tensors: Tensors that include sampling related metadata.
648
+
649
+ Returns:
650
+ (next_token_ids, parent_seq_ids) for each seq group in a batch.
651
+ If sampling is skipped, it returns ([], [])
652
+ sampled_token_ids_tensor: A tensor of sampled token ids.
653
+ """
654
+ return _sample_with_torch(
655
+ probs,
656
+ logprobs,
657
+ sampling_metadata,
658
+ include_gpu_probs_tensor=include_gpu_probs_tensor,
659
+ modify_greedy_probs=modify_greedy_probs,
660
+ )
661
+
662
+ # TODO: Enable once Triton kernel & associated code is faster.
663
+ # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
664
+ # sampling_tensors)
665
+
666
+
667
+ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
668
+ """
669
+ This function calculates the ranks of the chosen tokens in a logprob tensor.
670
+
671
+ Args:
672
+ x (torch.Tensor): 2D logprob tensor of shape (N, M)
673
+ where N is the no. of tokens and M is the vocab dim.
674
+ indices (torch.Tensor): List of chosen token indices.
675
+
676
+ Returns:
677
+ torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
678
+ Each element in the returned tensor represents the rank
679
+ of the chosen token in the input logprob tensor.
680
+ """
681
+ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
682
+ indices]
683
+ return (x > vals[:, None]).long().sum(1).add_(1)
684
+
685
+
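A toy check of the rank computation (assuming the wheel is installed; rank 1 means the chosen token had the highest logprob in its row):

    import torch
    from vllm.model_executor.layers.sampler import _get_ranks

    logprobs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]).log()
    print(_get_ranks(logprobs, torch.tensor([1, 2])))  # tensor([2, 1])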
686
+ def _get_logprobs(
687
+ logprobs: torch.Tensor,
688
+ sampling_metadata: SamplingMetadata,
689
+ sample_results: SampleResultType,
690
+ ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
691
+ """Return sample lobprobs and prompt logprobs.
692
+
693
+ The logic consists of 3 parts.
694
+ - Select indices to compute logprob from, ranks of token ids, and
695
+ the top k token ids from logprobs.
696
+ - Compute prompt logprobs if required.
697
+ - Compute sample logprobs if required.
698
+
699
+ Args:
700
+ logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
701
+ logprob per vocab. Sequence groups' query tokens are batched in a
702
+ single flattened tensor. For example, assuming there are N
703
+ seq groups, it is sorted by prefill tokens for seq_group_1 (if
704
+ prompt logprob is enabled), decode tokens for seq_group_1 (if
705
+ sampling is required), prefill tokens for seq_group_2, ...
706
+ sampling_metadata: The sampling metadata.
707
+ sample_results: (num_seq_groups) The tuple of (next_token_ids,
708
+ parent_ids) for each sequence group. When beam search is enabled,
709
+ sample_results can contain different number of seq_ids from
710
+ sampling_metadata.seq_groups. It is because beam search creates
711
+ 2 * BEAM_WIDTH number of samples (whereas there are only up to
712
+ BEAM_WIDTH number of seq_ids).
713
+
714
+ Returns:
715
+ A tuple of prompt and sample logprobs per sequence group in a batch.
716
+ """
717
+ # The index of query token to calculate logprobs. It includes both
718
+ # prompt and sample logprob indices.
719
+ query_indices: List[int] = []
720
+ # The next token ids to get the logprob value from.
721
+ next_token_ids: List[int] = []
722
+ # The largest requested number of logprobs. We find logprobs as many as the
723
+ # largest num logprobs in this API.
724
+ largest_num_logprobs = 1
725
+
726
+ # Select indices to compute logprob from, ranks of token ids, and the top
727
+ # k token ids from logprobs.
728
+ for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
729
+ sample_results):
730
+ sampling_params = seq_group.sampling_params
731
+
732
+ # Update indices and tokens for prompt logprobs.
733
+ if (seq_group.is_prompt
734
+ and sampling_params.prompt_logprobs is not None):
735
+ largest_num_logprobs = max(largest_num_logprobs,
736
+ sampling_params.prompt_logprobs)
737
+ next_prompt_tokens = _get_next_prompt_tokens(seq_group)
738
+ query_indices.extend(seq_group.prompt_logprob_indices)
739
+ next_token_ids.extend(next_prompt_tokens)
740
+
741
+ # Update indices and next tokens for sample logprob.
742
+ if seq_group.do_sample:
743
+ token_ids, parent_seq_ids = sample_result
744
+ # NOTE: We cannot directly use sample_indices because
745
+ # sample_indices only contain parent seq_ids of a previous step.
746
+ # The current step may have different number of seq_ids, and
747
+ # we can obtain it from `sample_result[1]`.
748
+ query_idx = seq_group.sample_indices[0]
749
+ query_indices.extend(
750
+ [query_idx + parent_id for parent_id in parent_seq_ids])
751
+ next_token_ids.extend(token_ids)
752
+
753
+ if sampling_params.logprobs is not None:
754
+ largest_num_logprobs = max(largest_num_logprobs,
755
+ sampling_params.logprobs)
756
+
757
+ assert len(next_token_ids) == len(query_indices)
758
+
759
+ if len(query_indices) == 0:
760
+ empty_sampled_logprob: SampleLogprobs = []
761
+ empty_prompt_logprob: Optional[PromptLogprobs] = None
762
+ return [empty_prompt_logprob], [empty_sampled_logprob]
763
+
764
+ query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
765
+ next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
766
+
767
+ # (num_selected_query_tokens, num_logprobs). Note that query_indices can
768
+ # contain duplicates if beam search is enabled.
769
+ selected_logprobs = logprobs[[
770
+ query_indices_gpu,
771
+ next_token_ids_gpu,
772
+ ]]
773
+ ranks = _get_ranks(
774
+ logprobs[query_indices_gpu],
775
+ next_token_ids_gpu,
776
+ )
777
+ assert selected_logprobs.shape[0] == ranks.shape[0]
778
+
779
+ # Logprobs of topk tokens for a batch of sequence groups.
780
+ # (num_query_tokens_across_batch).
781
+ if largest_num_logprobs > 0:
782
+ top_logprobs, top_token_ids = torch.topk(logprobs,
783
+ largest_num_logprobs,
784
+ dim=-1)
785
+ top_logprobs = top_logprobs.cpu()
786
+ top_token_ids = top_token_ids.cpu()
787
+ else:
788
+ top_logprobs, top_token_ids = None, None
789
+
790
+ selected_logprobs = selected_logprobs.cpu()
791
+ ranks = ranks.cpu()
792
+
793
+ # Find prompt/sample logprobs.
794
+ prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
795
+ sample_logprobs_per_seq_group: List[SampleLogprobs] = []
796
+ top_logprob_idx = 0
797
+ selected_logprobs_idx = 0
798
+
799
+ for seq_group, sample_result in zip(sampling_metadata.seq_groups,
800
+ sample_results):
801
+ (prompt_logprobs, top_logprob_idx,
802
+ selected_logprobs_idx) = _get_prompt_logprob_if_needed(
803
+ seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
804
+ selected_logprobs_idx, top_logprob_idx)
805
+ prompt_logprobs_per_seq_group.append(prompt_logprobs)
806
+
807
+ (sampled_logprobs, top_logprob_idx,
808
+ selected_logprobs_idx) = _get_sampled_logprob_if_needed(
809
+ seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
810
+ top_logprobs, selected_logprobs_idx, top_logprob_idx)
811
+ sample_logprobs_per_seq_group.append(sampled_logprobs)
812
+
813
+ return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
814
+
815
+
816
+ def _get_prompt_logprob_if_needed(
817
+ seq_group: SequenceGroupToSample,
818
+ selected_logprobs: torch.Tensor,
819
+ ranks: torch.Tensor,
820
+ top_token_ids: torch.Tensor,
821
+ top_logprobs: torch.Tensor,
822
+ selected_logprobs_idx: int,
823
+ top_logprob_idx: int,
824
+ ):
825
+ """Compute the prompt logprob from a sequence group if needed."""
826
+ sampling_params = seq_group.sampling_params
827
+ is_prompt = seq_group.is_prompt
828
+
829
+ # Find prompt logprobs
830
+ prompt_logprobs: Optional[PromptLogprobs] = None
831
+ if (is_prompt and sampling_params.prompt_logprobs is not None):
832
+ prompt_logprobs = []
833
+ num_logprobs = sampling_params.prompt_logprobs
834
+ next_prompt_tokens = _get_next_prompt_tokens(seq_group)
835
+ for token_id in next_prompt_tokens:
836
+ # Calculate the prompt logprob of the real prompt tokens.
837
+ # Use tuple here for performance (to use tolist()).
838
+ # {token_id: (logprob, rank_from_vocab)}
839
+ prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
840
+ token_id: (selected_logprobs[selected_logprobs_idx].item(),
841
+ ranks[selected_logprobs_idx].item())
842
+ }
843
+
844
+ # Add top K prompt logprobs along with its rank.
845
+ if num_logprobs > 0:
846
+ prompt_logprobs_dict.update(
847
+ zip(
848
+ top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
849
+ zip(
850
+ top_logprobs[
851
+ top_logprob_idx, :num_logprobs].tolist(),
852
+ # This is ranks. Since top_logprob is sorted,
853
+ # we can just use a range here.
854
+ range(1, num_logprobs + 1))))
855
+ prompt_logprobs.append({
856
+ token_id: Logprob(*logprob_and_rank)
857
+ for token_id, logprob_and_rank in prompt_logprobs_dict.items()
858
+ })
859
+ # + 1 to go to the next prompt token.
860
+ top_logprob_idx += 1
861
+ selected_logprobs_idx += 1
862
+ return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
863
+
864
+
865
+ def _get_sampled_logprob_if_needed(
866
+ seq_group: SequenceGroupToSample,
867
+ sample_result: Tuple[List[int], List[int]],
868
+ selected_logprobs: torch.Tensor,
869
+ ranks: torch.Tensor,
870
+ top_token_ids: torch.Tensor,
871
+ top_logprobs: torch.Tensor,
872
+ selected_logprobs_idx: int,
873
+ top_logprob_idx: int,
874
+ ):
875
+ """Compute the sample logprob if needed."""
876
+ seq_ids = seq_group.seq_ids
877
+ num_logprobs = seq_group.sampling_params.logprobs
878
+ if num_logprobs is None:
879
+ num_logprobs = 0
880
+ sampled_logprobs: SampleLogprobs = []
881
+ next_token_ids, parent_seq_ids = sample_result
882
+
883
+ if seq_group.do_sample:
884
+ assert len(next_token_ids) > 0
885
+ for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
886
+ # Calculate the sample logprob of the real sampled tokens.
887
+ # Use tuple here for performance (to use tolist()).
888
+ # token_id: (logprob, rank_from_vocab)
889
+ sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
890
+ next_token_id:
891
+ (selected_logprobs[selected_logprobs_idx].item(),
892
+ ranks[selected_logprobs_idx].item())
893
+ }
894
+ # +1 to go to the next sampled token. Note that
895
+ # selected_logprobs can contain duplicates unlike top_logprobs
896
+ # when beam search is enabled.
897
+ selected_logprobs_idx += 1
898
+
899
+ # Second, add top K logprobs along with its rank.
900
+ if num_logprobs >= 0:
901
+ sampled_logprobs_dict.update(
902
+ zip(
903
+ top_token_ids[top_logprob_idx +
904
+ parent_id, :num_logprobs].tolist(),
905
+ zip(
906
+ top_logprobs[top_logprob_idx +
907
+ parent_id, :num_logprobs].tolist(),
908
+ # This is rank. Since top_logprob is sorted, we
909
+ # can just use a range here.
910
+ range(1, num_logprobs + 1))))
911
+ sampled_logprobs.append({
912
+ token_id: Logprob(*logprob_and_rank)
913
+ for token_id, logprob_and_rank in
914
+ sampled_logprobs_dict.items()
915
+ })
916
+ # There are len(seq_ids) number of sampled tokens for the current
917
+ # sequence group in top_logprobs. Jump to the next seq_group.
918
+ top_logprob_idx += len(seq_ids)
919
+ return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
920
+
921
+
922
+ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
923
+ sample_indices: torch.Tensor,
924
+ greedy_samples: torch.Tensor) -> None:
925
+ """Modify the probability distributions of the greedily-sampled tokens such
926
+ that each sampled token has a "probability" of 1.0. This is required by
927
+ speculative decoding, which depends on the sampling method being encoded
928
+ within the probability distribution for correctness.
929
+
930
+ # Why do we only need to do this for greedy sampling?
931
+
932
+ vLLM's sampler performs the following steps for greedy or multinomial
933
+ (random) sampling:
934
+ 1. Get logits from model.
935
+ 2. Modify logits according to per-sequence sampling parameters.
936
+ - Multiply by temperature, top-k and top-p masking, penalize tokens
937
+ according to their frequency, etc.
938
+ 3. Sample a token.
939
+ - Random sampling simply samples from the modified probability
940
+ distribution.
941
+ - Greedy sampling performs `argmax` to obtain the token with the
942
+ highest likelihood.
943
+
944
+ Ignoring greedy sampling for a moment, we find that the computed probability
945
+ distribution has the following property: we can sample from it independently
946
+ and find that the token sampled by the Sampler has a frequency corresponding
947
+ to how often we see it in our sampling. In other words, for tokens sampled
948
+ with vLLM's random SamplingType, the computed probability distribution
949
+ encodes the sampling methodology completely.
950
+
951
+ Greedy sampling does not normally have this property. vLLM modifies logits
952
+ according to sampling params, then performs `argmax`, then returns the
953
+ sampled token and the computed probability distribution. If we sample from
954
+ the distribution, we'll find the likelihood of the greedily-sampled token
955
+ is not always 1.0.
956
+
957
+ Since lossless speculative decoding requires that the sampling methodology
958
+ be encoded within the probability distribution, we are motivated to modify
959
+ the probability distribution such that the sampled token has probability 1
960
+ when speculative decoding is used.
961
+
962
+ NOTE: Alternatively, we could use an extremely low temperature to achieve
963
+ greedy sampling using multinomial computation and unite the codepaths. This
964
+ has implications on the overall design of the sampler, e.g. how to record
965
+ accurate logprobs for the user, so this improvement is deferred to later.
966
+ """
967
+ # NOTE: logprobs are not modified so they can be returned to the user.
968
+ probs[sample_indices, :] = 0
969
+ probs[sample_indices, greedy_samples] = 1.0
970
+
971
+
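After the in-place modification above, each selected row of probs is one-hot at the greedily chosen token, so sampling from it reproduces the argmax. A toy call (same import assumption; values arbitrary):

    import torch
    from vllm.model_executor.layers.sampler import _modify_greedy_probs_inplace

    probs = torch.tensor([[0.6, 0.3, 0.1]])
    _modify_greedy_probs_inplace(probs.log(), probs,
                                 torch.tensor([0]), torch.tensor([0]))
    print(probs)  # tensor([[1., 0., 0.]])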
972
+ def _build_sampler_output(
973
+ sample_results: SampleResultType,
974
+ sampling_metadata: SamplingMetadata,
975
+ prompt_logprobs: List[Optional[PromptLogprobs]],
976
+ sample_logprobs: List[SampleLogprobs],
977
+ on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
978
+ torch.Tensor]],
979
+ ) -> SamplerOutput:
980
+ """Construct Python objects with the output of sampling.
981
+
982
+ Args:
983
+ on_device_tensors: Tuple containing on-device tensors with the
984
+ probabilities used in sampling and the sampled token ids. This
985
+ allows post-processing without copies to CPU/serialization, e.g. in
986
+ speculative decoding rejection sampling.
987
+ """
988
+
989
+ sampler_output = []
990
+ for (seq_group, sample_result, group_prompt_logprobs,
991
+ group_sample_logprobs) in zip(sampling_metadata.seq_groups,
992
+ sample_results, prompt_logprobs,
993
+ sample_logprobs):
994
+ seq_ids = seq_group.seq_ids
995
+ next_token_ids, parent_ids = sample_result
996
+ seq_outputs = []
997
+ for parent_id, next_token_id, logprobs in zip(parent_ids,
998
+ next_token_ids,
999
+ group_sample_logprobs):
1000
+ seq_outputs.append(
1001
+ SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
1002
+ sampler_output.append(
1003
+ SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
1004
+
1005
+ # If not specified, store None values in SamplerOutput.
1006
+ if on_device_tensors is not None:
1007
+ (sampled_token_probs, logprobs_tensor,
1008
+ sampled_token_ids) = on_device_tensors
1009
+ else:
1010
+ sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
1011
+ None)
1012
+
1013
+ return SamplerOutput(
1014
+ outputs=sampler_output,
1015
+ sampled_token_probs=sampled_token_probs,
1016
+ sampled_token_ids=sampled_token_ids,
1017
+ logprobs=logprobs_tensor,
1018
+ )
1019
+
1020
+
1021
+ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
1022
+ """Get a list of next prompt tokens to compute logprob from a
1023
+ given sequence group.
1024
+
1025
+ It is used to compute prompt logprob. Imagine you have logprob for each
1026
+ query token. Query token needs to know the next prompt token id to compute
1027
+ prompt logprob. This is a helper to obtain next prompt token ids.
1028
+
1029
+ This API has to be used only when the caller knows seq_group is in prefill
1030
+ stage.
1031
+
1032
+ Returns:
1033
+ A list of next prompt tokens to compute logprob.
1034
+ """
1035
+ assert seq_group.is_prompt, (
1036
+ "Caller should ensure the sequence group is in a prefill stage.")
1037
+ seq_ids = seq_group.seq_ids
1038
+ query_len = seq_group.query_len
1039
+ assert query_len is not None
1040
+ # prompt has only 1 seq id.
1041
+ assert len(seq_ids) == 1
1042
+ seq_data = seq_group.seq_data[seq_ids[0]]
1043
+ computed_len = seq_data.get_num_computed_tokens()
1044
+ prompt_tokens = seq_data.prompt_token_ids
1045
+ # +1 because we are looking for a next prompt token.
1046
+ next_token_index_start = computed_len + 1
1047
+ next_token_index_end = min(computed_len + query_len + 1,
1048
+ len(prompt_tokens))
1049
+ next_prompt_tokens = prompt_tokens[
1050
+ next_token_index_start:next_token_index_end]
1051
+ return next_prompt_tokens
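A worked example of the index arithmetic above, under the assumption of a chunked prefill partway through a 6-token prompt:

    prompt_tokens = [11, 12, 13, 14, 15, 16]
    computed_len, query_len = 2, 3                     # already processed / this chunk
    start = computed_len + 1                           # 3
    end = min(computed_len + query_len + 1, len(prompt_tokens))  # 6
    print(prompt_tokens[start:end])                    # [14, 15, 16]

Each query position's logprob is paired with the prompt token that follows it, which is exactly what prompt logprobs are computed against.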