vllm-npu 0.4.2 (vllm_npu-0.4.2-py3-none-any.whl)

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/sampling_metadata.py
@@ -0,0 +1,588 @@
+ import random
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+
+ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
+ from vllm.sampling_params import SamplingParams, SamplingType
+ from vllm.sequence import SequenceData, SequenceGroupMetadata
+ from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
+                         maybe_expand_dim)
+
+ _SAMPLING_EPS = 1e-5
+ _SEED_0_REPLACEMENT = 3403598558
+
+
+ @dataclass
+ class SequenceGroupToSample:
+     # |---------- N-1 iteration --------|
+     # |---------------- N iteration ---------------------|
+     # |- tokenA -|......................|-- newTokens ---|
+     # |---------- context_len ----------|
+     # |-------------------- seq_len ----------------------|
+     #                                   |-- query_len ---|
+
+     # Sequence ids for the sequence group in a previous step.
+     seq_ids: List[int]
+     sampling_params: SamplingParams
+     # seq_id -> sequence data.
+     seq_data: Dict[int, SequenceData]
+     # The length of the sequence (all tokens seen in the past + new token to
+     # compute attention) of the sequence group. None if it is in a decode
+     # stage.
+     seq_len: Optional[int]
+     # The length of new query tokens to compute in the current step. None if
+     # it is in a decode stage. query_len <= seq_len if chunked prefill is
+     # enabled.
+     query_len: Optional[int]
+     # A random number generator for sampling.
+     generator: Optional[torch.Generator]
+     # True if the sequence group is in prefill stage. False if it is in a
+     # decode stage.
+     is_prompt: bool
+     # Query token indices from logits, used to compute prompt logprobs. Empty
+     # if prompt logprobs are not required.
+     prompt_logprob_indices: List[int]
+     # Sample token indices from logits. Empty if sampling is not required.
+     sample_indices: List[int]
+
+     @property
+     def do_sample(self):
+         return len(self.sample_indices) > 0
+
+     def __post_init__(self):
+         if len(self.prompt_logprob_indices) > 0:
+             assert self.sampling_params.prompt_logprobs is not None
+         if self.is_prompt:
+             assert self.seq_len is not None
+             assert self.query_len is not None
+
+
+ class SamplingMetadata:
+     """Metadata for input sequences. Used in sampler.
+
+     The usage is as follows:
+     ```
+     hidden_states = execute_model(...)
+     logits = hidden_states[sampling_metadata.selected_token_indices]
+     sample(logits)
+
+     def sample(logits):
+         # Use categorized_sample_indices for sampling....
+     ```
+
+     Args:
+         seq_groups: List of batched sequence groups.
+         selected_token_indices: (num_query_tokens_to_logprob). Indices to find
+             logits from the initial model output hidden states.
+         categorized_sample_indices: SamplingType -> token indices to sample.
+             Each entry is a 2D tensor of shape (num_indices, 2), where the
+             first item is the sample index within the returned logits
+             (before pruning padding), and the second item is the sample
+             index after pruning using selected_token_indices.
+             For example, if the returned logits are [1, 2, 3] and we select
+             [1, 2] for sampling, the pruned logits are [2, 3]. In this case,
+             the first column is [1, 2] (sample indices within the original
+             logits), and the second column is [0, 1] (sample indices within
+             the pruned logits).
+         num_prompts: Number of prompt sequence groups in seq_groups.
+     """
+
+     def __init__(
+         self,
+         seq_groups: List[SequenceGroupToSample],
+         selected_token_indices: torch.Tensor,
+         categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+         num_prompts: int,
+     ) -> None:
+         self.seq_groups = seq_groups
+         self.selected_token_indices = selected_token_indices
+         self.categorized_sample_indices = categorized_sample_indices
+         self.num_prompts = num_prompts
+
+     @staticmethod
+     def prepare(
+         seq_group_metadata_list: List[SequenceGroupMetadata],
+         seq_lens: List[int],
+         query_lens: Optional[List[int]],
+         device: str,
+         pin_memory: bool,
+     ) -> "SamplingMetadata":
+         (
+             seq_groups,
+             selected_token_indices,
+             categorized_sample_indices,
+             num_prompts,
+         ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
+                                 device)
+         selected_token_indices = async_tensor_h2d(selected_token_indices,
+                                                   dtype=torch.long,
+                                                   target_device=device,
+                                                   pin_memory=pin_memory)
+         categorized_sample_indices = {
+             t: maybe_expand_dim(
+                 async_tensor_h2d(seq_ids,
+                                  dtype=torch.int,
+                                  target_device=device,
+                                  pin_memory=pin_memory), 2, 2)
+             for t, seq_ids in categorized_sample_indices.items()
+         }
+
+         sampling_metadata = SamplingMetadata(
+             seq_groups=seq_groups,
+             selected_token_indices=selected_token_indices,
+             categorized_sample_indices=categorized_sample_indices,
+             num_prompts=num_prompts,
+         )
+         return sampling_metadata
+
+     def __repr__(self) -> str:
+         return (
+             "SamplingMetadata("
+             f"seq_groups={self.seq_groups}, "
+             f"selected_token_indices={self.selected_token_indices}, "
+             f"categorized_sample_indices={self.categorized_sample_indices}), ")
+
+
+ def _prepare_seq_groups(
+     seq_group_metadata_list: List[SequenceGroupMetadata],
+     seq_lens: List[int],
+     query_lens: Optional[List[int]],
+     device: str,
+ ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
+         SamplingType, List[Tuple[int, int]]], int]:
+     """Prepare sequence groups and indices for sampling.
+
+     Args:
+         seq_group_metadata_list: A list of sequence groups to batch.
+         seq_lens: A list of sequence lengths per sequence group.
+             Indices should match seq_group_metadata_list.
+         query_lens: A list of query lengths. A query may cover only part of
+             the prompt tokens, so it can be shorter than the sequence length.
+         device: A device to use for the random number generator,
+             `SequenceGroupToSample.generator`.
+
+     Returns:
+         seq_groups: A list of sequence groups to sample.
+         selected_token_indices: See the definition in `SamplingMetadata`.
+         categorized_sample_indices: See the definition in `SamplingMetadata`.
+         num_prompts: Total number of prompts from `seq_group_metadata_list`.
+     """
+     # Batched sequence groups for the current model forward step.
+     seq_groups: List[SequenceGroupToSample] = []
+     # A list of token indices to sample/compute logprobs for. It is used to
+     # prune the model's output logits for performance.
+     selected_token_indices: List[int] = []
+     # Used for selected_token_indices.
+     model_output_idx = 0
+
+     # Sampling type -> (
+     #     indices to sample/prompt logprob within pruned output logits,
+     #     indices to sample within pruned logits)
+     categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
+         t: []
+         for t in SamplingType
+     }
+     # Index of logits to compute logprobs for. Logits include both prompt
+     # logprob and sample logprob indices.
+     logit_idx = 0
+     # Index to sample from a sample tensor. It is used by the Triton sample
+     # kernel. See `_sample_with_triton_kernel` for more details.
+     sample_idx = 0
+     # Total number of prompts from given sequence groups.
+     num_prompts = 0
+
+     for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+         seq_ids = list(seq_group_metadata.seq_data.keys())
+         sampling_params = seq_group_metadata.sampling_params
+         is_prompt = seq_group_metadata.is_prompt
+         generator: Optional[torch.Generator] = None
+         # If the current seq group is in the decode stage, it is None.
+         seq_len: Optional[int] = None
+         query_len: Optional[int] = None
+         prompt_logprob_indices: List[int] = []
+         sample_indices: List[int] = []
+         do_sample = seq_group_metadata.do_sample
+
+         if seq_group_metadata.is_prompt:
+             if sampling_params.seed is not None:
+                 seq_group_metadata.state.generator = torch.Generator(
+                     device=device).manual_seed(sampling_params.seed)
+
+             num_prompts += 1
+             num_prefill_sample = len(seq_ids)
+             assert num_prefill_sample == 1
+             assert query_lens is not None and seq_lens is not None
+             query_len, seq_len = query_lens[i], seq_lens[i]
+             # If we need sampling, exclude num_prefill_sample tokens from
+             # the prompt logprobs.
+             prompt_logprob_len = (query_len - num_prefill_sample
+                                   if do_sample else query_len)
+             sample_len = num_prefill_sample if do_sample else 0
+         else:
+             # Decode
+             prompt_logprob_len = 0
+             sample_len = len(seq_ids) if do_sample else 0
+
+         # Update indices to select from the model output.
+         """
+         This block computes selected_token_indices, which is used in the
+         following way.
+
+         hidden_states = model(...)
+         logits = hidden_states[selected_token_indices]
+         """
+
+         if sampling_params.prompt_logprobs:
+             selected_token_indices.extend(
+                 range(model_output_idx, model_output_idx + prompt_logprob_len))
+         model_output_idx += prompt_logprob_len
+         if do_sample:
+             selected_token_indices.extend(
+                 range(model_output_idx, model_output_idx + sample_len))
+         model_output_idx += sample_len
+
+         # We now find indices for logprob computation and sampling.
+         """
+         This block computes categorized_sample_indices, which is used in the
+         following way.
+
+         hidden_states = model(...)
+         logits = hidden_states[selected_token_indices]
+         def sample(logits):
+             # Use categorized_sample_indices for sampling.
+             # prompt_logprob_indices to find prompt logprob indices.
+             # sample_indices to find sample indices.
+         """
+
+         if sampling_params.prompt_logprobs is not None:
+             prompt_logprob_indices.extend(
+                 range(logit_idx, logit_idx + prompt_logprob_len))
+             logit_idx += prompt_logprob_len
+         if do_sample:
+             sample_indices.extend(range(logit_idx, logit_idx + sample_len))
+             categorized_sample_indices[sampling_params.sampling_type].extend(
+                 list(
+                     zip(range(logit_idx, logit_idx + sample_len),
+                         range(sample_idx, sample_idx + sample_len))))
+             logit_idx += sample_len
+             sample_idx += sample_len
+
+         if sampling_params.seed is not None:
+             generator = seq_group_metadata.state.generator
+
+         seq_groups.append(
+             SequenceGroupToSample(
+                 seq_ids=seq_ids,
+                 sampling_params=sampling_params,
+                 seq_data=seq_group_metadata.seq_data,
+                 seq_len=seq_len,
+                 query_len=query_len,
+                 generator=generator,
+                 is_prompt=is_prompt,
+                 prompt_logprob_indices=list(prompt_logprob_indices),
+                 sample_indices=list(sample_indices)))
+     return (seq_groups, selected_token_indices, categorized_sample_indices,
+             num_prompts)
+
+
+ @dataclass
+ class SamplingTensors:
+     """Tensors for sampling."""
+
+     temperatures: torch.Tensor
+     top_ps: torch.Tensor
+     top_ks: torch.Tensor
+     min_ps: torch.Tensor
+     presence_penalties: torch.Tensor
+     frequency_penalties: torch.Tensor
+     repetition_penalties: torch.Tensor
+     sampling_seeds: torch.Tensor
+     sample_indices: torch.Tensor
+     extra_seeds: Optional[torch.Tensor]
+     prompt_tokens: torch.Tensor
+     output_tokens: torch.Tensor
+
+     @classmethod
+     def from_sampling_metadata(
+         cls,
+         sampling_metadata: "SamplingMetadata",
+         vocab_size: int,
+         device: torch.device,
+         dtype: torch.dtype,
+         *,
+         extra_seeds_to_generate: int = 0,
+         extra_entropy: Optional[Tuple[int, ...]] = None
+     ) -> Tuple["SamplingTensors", bool, bool, bool]:
+         """
+         extra_seeds_to_generate: extra seeds to generate using the
+             user-defined seed for each sequence.
+         extra_entropy: extra entropy to use when generating seeds.
+         """
+         prompt_tokens: List[List[int]] = []
+         output_tokens: List[List[int]] = []
+         top_ks: List[int] = []
+         temperatures: List[float] = []
+         top_ps: List[float] = []
+         min_ps: List[float] = []
+         presence_penalties: List[float] = []
+         frequency_penalties: List[float] = []
+         repetition_penalties: List[float] = []
+         sampling_seeds: List[int] = []
+         sample_indices: List[int] = []
+         prompt_best_of: List[int] = []
+         do_penalties = False
+         do_top_p_top_k = False
+         do_min_p = False
+
+         # We need one base seed per Triton slice.
+         seeds_to_generate = (extra_seeds_to_generate +
+                              get_num_triton_sampler_splits(vocab_size))
+
+         assert sampling_metadata.seq_groups is not None
+         for seq_group in sampling_metadata.seq_groups:
+             seq_ids = seq_group.seq_ids
+             sampling_params = seq_group.sampling_params
+             temperature = sampling_params.temperature
+             p = sampling_params.presence_penalty
+             f = sampling_params.frequency_penalty
+             r = sampling_params.repetition_penalty
+             top_p = sampling_params.top_p
+             min_p = sampling_params.min_p
+             seed = sampling_params.seed
+
+             is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
+
+             # k should not be greater than the vocab size.
+             top_k = min(sampling_params.top_k, vocab_size)
+             top_k = vocab_size if top_k == -1 else top_k
+             if temperature < _SAMPLING_EPS:
+                 # NOTE: Zero temperature means deterministic sampling
+                 # (i.e., greedy sampling or beam search).
+                 # Set the temperature to 1 to avoid division by zero.
+                 temperature = 1.0
+             if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
+                                        or top_k != vocab_size):
+                 do_top_p_top_k = True
+             if not do_min_p and min_p > _SAMPLING_EPS:
+                 do_min_p = True
+             if not do_penalties and (abs(p) >= _SAMPLING_EPS
+                                      or abs(f) >= _SAMPLING_EPS
+                                      or abs(r - 1.0) >= _SAMPLING_EPS):
+                 do_penalties = True
+
+             is_prompt = seq_group.is_prompt
+             if (seq_group.is_prompt
+                     and sampling_params.prompt_logprobs is not None):
+                 # For these prompt tokens we only need their logprobs.
+                 query_len = seq_group.query_len
+                 assert query_len is not None
+                 prefill_len = len(seq_group.prompt_logprob_indices)
+                 temperatures += [temperature] * prefill_len
+                 top_ps += [top_p] * prefill_len
+                 top_ks += [top_k] * prefill_len
+                 min_ps += [min_p] * prefill_len
+                 presence_penalties += [0] * prefill_len
+                 frequency_penalties += [0] * prefill_len
+                 repetition_penalties += [1] * prefill_len
+                 prompt_tokens.extend([] for _ in range(prefill_len))
+                 output_tokens.extend([] for _ in range(prefill_len))
+
+             if seq_group.do_sample:
+                 sample_lens = len(seq_group.sample_indices)
+                 assert sample_lens == len(seq_ids)
+                 for seq_id in seq_ids:
+                     seq_data = seq_group.seq_data[seq_id]
+                     prompt_tokens.append(seq_data.prompt_token_ids)
+                     output_tokens.append(seq_data.output_token_ids)
+                 temperatures += [temperature] * len(seq_ids)
+                 top_ps += [top_p] * len(seq_ids)
+                 top_ks += [top_k] * len(seq_ids)
+                 min_ps += [min_p] * len(seq_ids)
+                 presence_penalties += [p] * len(seq_ids)
+                 frequency_penalties += [f] * len(seq_ids)
+                 repetition_penalties += [r] * len(seq_ids)
+
+             if is_prompt:
+                 prompt_best_of.append(sampling_params.best_of)
+                 query_len = seq_group.query_len
+                 assert query_len is not None
+
+             for seq_id in seq_ids:
+                 seq_data = seq_group.seq_data[seq_id]
+                 extra_entropy = extra_entropy or ()
+                 seq_seeds = cls._get_sequence_seeds(
+                     seed,
+                     seq_data.get_len(),
+                     *extra_entropy,
+                     seq_id,
+                     seeds_to_generate=seeds_to_generate,
+                     is_greedy=is_greedy)
+                 sampling_seeds.append(seq_seeds)
+             sample_indices.extend(seq_group.sample_indices)
+
+         sampling_tensors = SamplingTensors.from_lists(
+             temperatures, top_ps, top_ks, min_ps, presence_penalties,
+             frequency_penalties, repetition_penalties, sampling_seeds,
+             sample_indices, prompt_tokens, output_tokens, vocab_size,
+             extra_seeds_to_generate, device, dtype)
+         return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+
+     @classmethod
+     def from_lists(cls, temperatures: List[float], top_ps: List[float],
+                    top_ks: List[int], min_ps: List[float],
+                    presence_penalties: List[float],
+                    frequency_penalties: List[float],
+                    repetition_penalties: List[float],
+                    sampling_seeds: List[int], sample_indices: List[int],
+                    prompt_tokens: List[List[int]],
+                    output_tokens: List[List[int]], vocab_size: int,
+                    extra_seeds_to_generate: int, device: torch.device,
+                    dtype: torch.dtype) -> "SamplingTensors":
+         # Note that the performance will be very bad without
+         # pinned memory.
+         pin_memory = is_pin_memory_available()
+         prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
+                              default=0)
+         prompt_padded_tokens = [
+             tokens + [vocab_size] * (prompt_max_len - len(tokens))
+             for tokens in prompt_tokens
+         ]
+         output_max_len = max([len(tokens) for tokens in output_tokens],
+                              default=0)
+         output_padded_tokens = [
+             tokens + [vocab_size] * (output_max_len - len(tokens))
+             for tokens in output_tokens
+         ]
+
+         temperatures_t = torch.tensor(
+             temperatures,
+             device="cpu",
+             dtype=dtype,
+             pin_memory=pin_memory,
+         )
+         top_ps_t = torch.tensor(
+             top_ps,
+             device="cpu",
+             dtype=dtype,
+             pin_memory=pin_memory,
+         )
+         min_ps_t = torch.tensor(
+             min_ps,
+             device="cpu",
+             dtype=dtype,
+             pin_memory=pin_memory,
+         )
+         presence_penalties_t = torch.tensor(
+             presence_penalties,
+             device="cpu",
+             dtype=dtype,
+             pin_memory=pin_memory,
+         )
+         frequency_penalties_t = torch.tensor(
+             frequency_penalties,
+             device="cpu",
+             dtype=dtype,
+             pin_memory=pin_memory,
+         )
+         repetition_penalties_t = torch.tensor(
+             repetition_penalties,
+             device="cpu",
+             dtype=dtype,
+             pin_memory=pin_memory,
+         )
+         top_ks_t = torch.tensor(
+             top_ks,
+             device="cpu",
+             dtype=torch.int,
+             pin_memory=pin_memory,
+         )
+         sample_indices_t = torch.tensor(
+             sample_indices,
+             device="cpu",
+             dtype=torch.long,
+             pin_memory=pin_memory,
+         )
+         prompt_tensor = torch.tensor(
+             prompt_padded_tokens,
+             device="cpu",
+             dtype=torch.long,
+             pin_memory=pin_memory,
+         )
+         output_tensor = torch.tensor(
+             output_padded_tokens,
+             device="cpu",
+             dtype=torch.long,
+             pin_memory=pin_memory,
+         )
+         # We need to transpose and make the tensor contiguous to
+         # copy it correctly.
+         # [batch_size, n_seeds] -> [n_seeds, batch_size]
+         sampling_seeds_t = torch.tensor(
+             sampling_seeds,
+             device="cpu",
+             dtype=torch.long,
+             pin_memory=pin_memory,
+         ).T.contiguous()
+
+         # Because the memory is pinned, we can do non-blocking
+         # transfers to the device.
+
+         # How many seeds the sample operation itself will need.
+         num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
+         sampling_seeds_gpu = sampling_seeds_t.to(device=device,
+                                                  non_blocking=True)
+         extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
+         if not extra_seeds_gpu.numel():
+             extra_seeds_gpu = None
+         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
+
+         return cls(
+             temperatures=temperatures_t.to(device=device, non_blocking=True),
+             top_ps=top_ps_t.to(device=device, non_blocking=True),
+             top_ks=top_ks_t.to(device=device, non_blocking=True),
+             min_ps=min_ps_t.to(device=device, non_blocking=True),
+             presence_penalties=presence_penalties_t.to(device=device,
+                                                        non_blocking=True),
+             frequency_penalties=frequency_penalties_t.to(device=device,
+                                                          non_blocking=True),
+             repetition_penalties=repetition_penalties_t.to(device=device,
+                                                            non_blocking=True),
+             prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
+             output_tokens=output_tensor.to(device=device, non_blocking=True),
+             sampling_seeds=sampling_seeds_gpu,
+             sample_indices=sample_indices_t.to(device=device,
+                                                non_blocking=True),
+             extra_seeds=extra_seeds_gpu,
+         )
+
+     @staticmethod
+     def _get_sequence_seeds(
+         seed: int,
+         *extra_entropy: int,
+         seeds_to_generate: int,
+         is_greedy: bool,
+     ):
+         """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
+         if not is_greedy:
+             if seed is None:
+                 randint_fn = random.randint
+             else:
+                 generator = random.Random(str((seed, ) + extra_entropy))
+                 randint_fn = generator.randint
+             lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
+             # If the user/random chooses seed = 0 but the request should
+             # use random sampling, we need to change it to something
+             # else. We use a constant in that case.
+             # This way we don't need to create and load a bool
+             # matrix in the sampling kernel, which reduces CPU
+             # overhead and latency.
+             seq_seeds = [
+                 randint_fn(lo, hi) or _SEED_0_REPLACEMENT
+                 for _ in range(seeds_to_generate)
+             ]
+         else:
+             # For the kernel, seed == 0 means greedy decoding.
+             seq_seeds = [0] * seeds_to_generate
+         return seq_seeds
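
Note: the SamplingMetadata docstring above describes an indexing contract between the model output and the sampler. The following is a small, self-contained sketch of that contract with toy tensors; the concrete values and the variable names below are illustrative only and are not part of the wheel.

    import torch

    # Toy model output: 4 query tokens were scheduled, hidden size 8.
    hidden_states = torch.randn(4, 8)

    # Suppose prompt logprobs are wanted for rows 0-2 and one token is sampled
    # at row 3: selected_token_indices keeps exactly those rows.
    selected_token_indices = torch.tensor([0, 1, 2, 3])
    pruned = hidden_states[selected_token_indices]     # shape (4, 8)

    # categorized_sample_indices stores (index within the original logits,
    # index within the pruned logits) pairs per SamplingType; here a single
    # random sample is taken from pruned row 3.
    random_sample_pairs = torch.tensor([[3, 3]])
    rows_to_sample = pruned[random_sample_pairs[:, 1]]  # rows handed to sampling

In the engine itself, `SamplingMetadata.prepare(...)` builds these tensors from the scheduler's `SequenceGroupMetadata` list, and `SamplingTensors.from_sampling_metadata(...)` additionally returns the `do_penalties`, `do_top_p_top_k`, and `do_min_p` flags so the sampler can skip passes that no request in the batch needs.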
vllm/model_executor/utils.py
@@ -0,0 +1,35 @@
+ """Utils for model executor."""
+ import random
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ import torch
+
+
+ def set_random_seed(seed: int) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+
+ def set_weight_attrs(
+     weight: torch.Tensor,
+     weight_attrs: Optional[Dict[str, Any]],
+ ):
+     """Set attributes on a weight tensor.
+
+     This method is used to set attributes on a weight tensor. This method
+     will not overwrite existing attributes.
+
+     Args:
+         weight: The weight tensor.
+         weight_attrs: A dictionary of attributes to set on the weight tensor.
+     """
+     if weight_attrs is None:
+         return
+     for key, value in weight_attrs.items():
+         assert not hasattr(
+             weight, key), (f"Overwriting existing tensor attribute: {key}")
+         setattr(weight, key, value)
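
A short usage sketch of these helpers, assuming the wheel above is installed; the `dummy_loader` attribute name is purely illustrative (layer code in the package attaches its own loader callbacks in the same way):

    import torch
    from vllm.model_executor.utils import set_random_seed, set_weight_attrs

    set_random_seed(1234)  # seeds Python, NumPy, and torch (plus CUDA if available)

    weight = torch.nn.Parameter(torch.empty(16, 16))

    def dummy_loader(param: torch.nn.Parameter, loaded: torch.Tensor) -> None:
        param.data.copy_(loaded)

    # Attach metadata to the parameter; setting the same key twice would trip
    # the assert inside set_weight_attrs.
    set_weight_attrs(weight, {"dummy_loader": dummy_loader})
    assert weight.dummy_loader is dummy_loader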