vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/layers/logits_processor.py
@@ -0,0 +1,115 @@
+"""A layer that computes logits from hidden_states."""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.distributed import tensor_model_parallel_gather
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class LogitsProcessor(nn.Module):
+    """Process logits and apply logits processors from sampling metadata.
+
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None,
+                 scale: Optional[float] = 1.0,
+                 logits_as_input: bool = False) -> None:
+        """
+        Args:
+            scale: A scaling factor to apply to the logits.
+        """
+        super().__init__()
+        self.scale = scale
+        self.vocab_size = vocab_size
+        # Whether the input is logits (default is hidden states).
+        self.logits_as_input = logits_as_input
+        # Original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def forward(
+        self,
+        embedding: torch.Tensor,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.logits_as_input:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)
+
+        if logits is not None:
+            logits *= self.scale
+
+            # Apply logits processors (if any).
+            logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
+    def extra_repr(self) -> str:
+        s = f"vocab_size={self.vocab_size}"
+        s += f", org_vocab_size={self.org_vocab_size}"
+        s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
+        return s
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
+
+
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    found_logits_processors = False
+    logits_processed = 0
+    for seq_group in sampling_metadata.seq_groups:
+        seq_ids = seq_group.seq_ids
+        sampling_params = seq_group.sampling_params
+        logits_processors = sampling_params.logits_processors
+
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id, logits_row_idx in zip(seq_ids,
+                                              seq_group.sample_indices):
+                logits_row = logits[logits_row_idx]
+                token_ids = seq_group.seq_data[seq_id].output_token_ids
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+
+        logits_processed += len(seq_group.sample_indices) + len(
+            seq_group.prompt_logprob_indices)
+
+    if found_logits_processors:
+        # Verifies that no rows in logits were missed unexpectedly.
+        assert logits_processed == logits.shape[0]
+    return logits
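
For orientation, a short sketch (paraphrased, not verbatim from this wheel) of how the model implementations listed above, for example vllm/model_executor/models/llama.py, typically hook this layer into their compute_logits step; the class skeleton and config fields below are illustrative assumptions.

import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata


class SomeModelForCausalLM:  # illustrative skeleton only
    def __init__(self, config):
        # ... model layers and an lm_head are created here ...
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        # Prunes hidden states to the selected token positions, multiplies by
        # the LM head weight, gathers across tensor-parallel ranks, and applies
        # any per-request logits processors from the sampling metadata.
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)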
vllm/model_executor/layers/ops/__init__.py (file without changes)
vllm/model_executor/layers/ops/rand.py
@@ -0,0 +1,157 @@
+from typing import Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+
+def seeded_uniform(
+    *size,
+    seeds: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[Union[torch.device, str]] = None,
+    pin_memory: Optional[bool] = False,
+) -> torch.Tensor:
+    """Similar to torch.rand, but allows for seeds to be set per row.
+
+    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
+    If it is 3d, the additional seeds needed will be derived automatically
+    in a deterministic fashion:
+    [
+        row 0: [columns_with_seed_0], [columns_with_seed_0^1], ...
+    ]
+    """
+    n_dims = len(size)
+
+    if n_dims > 3:
+        raise ValueError("seeded_uniform only supports up to 3D tensors")
+
+    if out is None:
+        out = torch.empty(*size,
+                          dtype=dtype,
+                          device=device,
+                          pin_memory=pin_memory)
+    elif out.shape != size:
+        raise ValueError("shape of out and size must be the same")
+
+    if n_dims == 3:
+        n_rows, n_3d, n_cols = out.shape
+        stride_row = out.stride(0)
+        stride_3d = out.stride(1)
+    elif n_dims == 2:
+        n_rows, n_cols = out.shape
+        n_3d = 1
+        stride_row = out.stride(0)
+        stride_3d = 1
+    else:
+        n_cols = out.shape[0]
+        n_rows = 1
+        n_3d = 1
+        stride_row = 1
+        stride_3d = 1
+
+    if seeds.ndim != 1:
+        raise ValueError("seeds must be a 1D tensor")
+
+    if seeds.numel() != n_rows:
+        raise ValueError(
+            "seeds must have the same number of elements as out has rows")
+
+    # The philox PRNG Triton uses generates 4 random numbers at once.
+    # Therefore, the most efficient use of it is to divide the
+    # block size by 4, and then save the generated random numbers to
+    # each of the 4 slices of the tensor.
+    full_block_size = triton.next_power_of_2(n_cols)
+    philox_block_size = max(full_block_size // 4, 1)
+    n_slices = full_block_size // philox_block_size
+    num_warps = 4
+    # Manual tuning. This seems to give best performance on A100 for
+    # simple kernels like this.
+    if philox_block_size >= 8192:
+        num_warps = 32
+    elif philox_block_size >= 4096:
+        num_warps = 16
+    elif philox_block_size >= 2048:
+        num_warps = 8
+
+    _seeded_uniform_triton[(n_rows, n_3d)](
+        out,
+        seeds,
+        stride_row,
+        stride_3d,
+        seeds.stride(0),
+        n_rows,
+        n_3d,
+        n_cols,
+        n_slices=n_slices,
+        num_warps=num_warps,
+        block_size=philox_block_size,
+    )
+    return out
+
+
+@triton.jit
+def _seeded_uniform_triton(
+    out_ptr: torch.Tensor,
+    seed_ptr: torch.Tensor,
+    out_row_stride: int,
+    out_3d_stride: int,
+    seed_row_stride: int,
+    n_rows: int,
+    n_3d: int,
+    n_cols: int,
+    n_slices: tl.constexpr,
+    block_size: tl.constexpr,
+):
+    """
+    Generate a random float32 number in [0, 1) for each element in the output
+    tensor. The random numbers in a row are generated using the seed for that
+    row.
+
+    Args:
+        out_ptr: The output tensor.
+        seed_ptr: The per-row seeds to use for random number generation.
+        out_row_stride: The stride between rows of the output tensor.
+        out_3d_stride: The stride between 3D slices of the output tensor.
+        seed_row_stride: The stride between rows of the seed tensor.
+        n_rows: The number of rows in the output tensor.
+        n_3d: The size of second dimension of the output tensor,
+            if output tensor is 3D.
+        n_cols: The number of columns in the output tensor.
+        n_slices: The number of philox outputs to use.
+    """
+    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")
+
+    # Get the row index.
+    row_idx = tl.program_id(axis=0)
+    three_d_idx = tl.program_id(axis=1)
+
+    philox_offsets = tl.arange(0, block_size)
+    # Get the seed for the current element.
+    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
+    if three_d_idx > 0:
+        seed ^= three_d_idx
+    # Generate random numbers in [0, 1).
+    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)
+
+    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
+                            three_d_idx * out_3d_stride)
+    out1_offsets = philox_offsets
+    tl.store(output_row_start_ptr + out1_offsets,
+             out1,
+             mask=out1_offsets < n_cols)
+    if n_slices > 1:
+        out2_offsets = tl.arange(block_size, block_size * 2)
+        tl.store(output_row_start_ptr + out2_offsets,
+                 out2,
+                 mask=out2_offsets < n_cols)
+    if n_slices > 2:
+        out3_offsets = tl.arange(block_size * 2, block_size * 3)
+        tl.store(output_row_start_ptr + out3_offsets,
+                 out3,
+                 mask=out3_offsets < n_cols)
+    if n_slices > 3:
+        out4_offsets = tl.arange(block_size * 3, block_size * 4)
+        tl.store(output_row_start_ptr + out4_offsets,
+                 out4,
+                 mask=out4_offsets < n_cols)
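
A quick, hedged usage example for seeded_uniform; the shapes and seed values below are made up, and it assumes a CUDA device since the kernel is written in Triton. Each row keeps its own reproducible random stream, so a second call with the same seeds returns the same tensor.

import torch

from vllm.model_executor.layers.ops.rand import seeded_uniform

# Four rows of eight uniform samples, one seed per row (illustrative values).
seeds = torch.tensor([1, 2, 3, 4], dtype=torch.long, device="cuda")
noise = seeded_uniform(4, 8, seeds=seeds, device="cuda", dtype=torch.float32)
noise_again = seeded_uniform(4, 8, seeds=seeds, device="cuda",
                             dtype=torch.float32)
assert torch.equal(noise, noise_again)  # per-row seeding is deterministic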
vllm/model_executor/layers/ops/sample.py
@@ -0,0 +1,406 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.model_executor.layers.ops.rand import seeded_uniform
+
+_EPS = 1e-6
+
+# This is a hardcoded limit in Triton (max block size).
+MAX_TRITON_N_COLS = 131072
+
+
+def get_num_triton_sampler_splits(n_cols: int) -> int:
+    """Get the number of splits to use for Triton sampling.
+
+    Triton has a limit on the number of columns it can handle, so we need to
+    split the tensor and call the kernel multiple times if it's too large.
+    """
+    return math.ceil(n_cols / MAX_TRITON_N_COLS)
+
+
+def _multi_split_sample(
+    probs: torch.Tensor,
+    seeds: torch.Tensor,
+    n_splits: int,
+    sampled_tokens_size: Tuple[int, int],
+    sampled_logprobs_size: Tuple[int, int],
+    sample_indices: torch.Tensor,
+    logprobs: torch.Tensor,
+    *,
+    modify_greedy_probs: bool = False,
+    save_logprobs: bool = False,
+):
+    """Sample tokens where vocab size is split into multiple parts
+    (too large for Triton otherwise)."""
+    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
+    split_probs = probs.tensor_split(n_splits, 1)
+    split_logprobs = logprobs.tensor_split(n_splits, 1)
+    sampled_tokens_tmp = [
+        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
+        for _ in range(n_splits)
+    ]
+    sampled_logprobs_tmp = [
+        torch.empty(sampled_logprobs_size,
+                    dtype=probs.dtype,
+                    device=probs.device) for _ in range(n_splits)
+    ]
+    # We are purposefully using sampled_tokens_size as we need to always
+    # save modified probs in this case.
+    sampled_modified_probs_tmp = [
+        torch.empty(sampled_tokens_size,
+                    dtype=probs.dtype,
+                    device=probs.device) for _ in range(n_splits)
+    ]
+    for i in range(n_splits):
+        n_samples = sample_indices.shape[0]
+        n_cols = split_probs[i].shape[1]
+        n_best = sampled_tokens_tmp[i].shape[1]
+        uniform_noise = seeded_uniform(n_samples,
+                                       n_best,
+                                       n_cols,
+                                       seeds=seeds[i].flatten(),
+                                       device=split_probs[i].device,
+                                       dtype=split_probs[i].dtype)
+        # TODO(yard1): See if we can remove the contiguous() calls.
+        # Will need kernel support.
+        _sample(
+            split_probs[i].contiguous(),
+            split_logprobs[i].contiguous(),
+            sample_indices,
+            sampled_tokens_tmp[i],
+            sampled_logprobs_tmp[i],
+            sampled_modified_probs_tmp[i],
+            seeds[i],
+            uniform_noise,
+            modify_greedy_probs=False,
+            save_logprobs=save_logprobs,
+            save_modified_probs=True,
+        )
+        if i > 0:
+            # Add offset to sampled tokens
+            sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
+    sampled_tokens = torch.stack(sampled_tokens_tmp)
+    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
+    # Reduce the results from the splits.
+    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
+                                                dim=0,
+                                                keepdim=True)
+    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
+    if save_logprobs:
+        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
+        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
+    else:
+        sampled_logprobs = None
+    sampled_modified_probs = sampled_modified_probs.squeeze(0)
+
+    if modify_greedy_probs:
+        # We need to modify the greedy probs for the sampled tokens.
+        # We can't do this in the kernel as we need to know the
+        # sampled tokens.
+        probs.fill_(0.0)
+        probs.scatter_(1, sampled_tokens, 1.0)
+
+    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
+
+
+def sample(
+    probs: torch.Tensor,
+    seeds: torch.Tensor,
+    *,
+    max_best_of: int = 1,
+    sample_indices: Optional[torch.Tensor] = None,
+    logprobs: Optional[torch.Tensor] = None,
+    modify_greedy_probs: bool = False,
+    save_logprobs: bool = False,
+    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+    """Sample tokens from probs, with per-sequence seeds.
+
+    Can sample from a subset of sequences through sample_indices.
+
+    Args:
+        probs: Probabilities to sample from.
+            shape = [batch_size, vocab_size]
+        seeds: Per-sequence seed values.
+            shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
+        max_best_of: Number of samples to generate per sequence.
+            Sequence seed will be incremented by 1 each time.
+        sample_indices: Indices of sequences to sample from.
+            If not provided, will sample from all sequences.
+            shape = [n]
+        logprobs: Log-probabilities of the sampled tokens.
+            Only used for saving the logprobs if save_logprobs is True.
+            shape = [batch_size, vocab_size]
+        modify_greedy_probs: Whether to modify the greedy probabilities
+            for speculative sampling (sampled token = 1.0,
+            everything else = 0.0).
+        save_logprobs: Whether to save the log-probabilities of the
+            sampled tokens to a tensor.
+        _save_modified_probs: Whether to save the modified probabilities
+            (including gumbel noise) of the sampled tokens to a tensor.
+            DOES NOT include the modification done by modify_greedy_probs
+            (because we want to use the unmodified probs to pick the best
+            split in case of multi-split sampling).
+            This is exposed only for testing.
+
+    Returns:
+        sampled_tokens: shape = [n, max_best_of]
+        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
+        sampled_modified_probs: shape = [n, max_best_of]
+            if save_modified_probs else None
+    """
+    if sample_indices is None:
+        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)
+
+    sampled_tokens_size = (sample_indices.size(0), max_best_of)
+    if save_logprobs:
+        if logprobs is None:
+            raise ValueError(
+                "logprobs tensor must be provided if save_logprobs is True")
+        sampled_logprobs_size = sampled_tokens_size
+    else:
+        # Empty tensors to invoke the kernel
+        sampled_logprobs_size = (0, 0)
+        logprobs = probs
+
+    assert logprobs is not None
+    if _save_modified_probs:
+        sampled_modified_probs_size = sampled_tokens_size
+    else:
+        # Empty tensors to invoke the kernel
+        sampled_modified_probs_size = (0, 0)
+
+    # If the number of columns in probs is too large for Triton to handle,
+    # we split the tensor and sample from each split separately, and then
+    # do an argmax+gather to combine the results.
+    n_splits = get_num_triton_sampler_splits(probs.shape[1])
+    if n_splits > 1:
+        (sampled_tokens, sampled_logprobs,
+         sampled_modified_probs) = _multi_split_sample(
+             probs,
+             seeds,
+             n_splits,
+             sampled_tokens_size,
+             sampled_logprobs_size,
+             sample_indices,
+             logprobs=logprobs,
+             modify_greedy_probs=modify_greedy_probs,
+             save_logprobs=save_logprobs)
+    else:
+        sampled_tokens = torch.empty(sampled_tokens_size,
+                                     dtype=torch.long,
+                                     device=probs.device)
+        sampled_logprobs = torch.empty(sampled_logprobs_size,
+                                       dtype=probs.dtype,
+                                       device=probs.device)
+        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
+                                             dtype=probs.dtype,
+                                             device=probs.device)
+        n_samples = sample_indices.shape[0]
+        n_cols = probs.shape[1]
+        uniform_noise = seeded_uniform(n_samples,
+                                       max_best_of,
+                                       n_cols,
+                                       seeds=seeds.flatten(),
+                                       device=probs.device,
+                                       dtype=probs.dtype)
+
+        _sample(
+            probs,
+            logprobs,
+            sample_indices,
+            sampled_tokens,
+            sampled_logprobs,
+            sampled_modified_probs,
+            seeds,
+            uniform_noise,
+            modify_greedy_probs=modify_greedy_probs,
+            save_logprobs=save_logprobs,
+            save_modified_probs=_save_modified_probs,
+        )
+    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
+            sampled_modified_probs if _save_modified_probs else None)
+
+
+def _sample(probs: torch.Tensor,
+            logprobs: torch.Tensor,
+            sample_indices: torch.Tensor,
+            output_samples: torch.Tensor,
+            output_logprobs: torch.Tensor,
+            output_modified_probs: torch.Tensor,
+            seeds: torch.Tensor,
+            uniform_noise: torch.Tensor,
+            *,
+            modify_greedy_probs: bool = False,
+            save_logprobs: bool = True,
+            save_modified_probs: bool = False) -> torch.Tensor:
+    """Sample tokens from probs.
+
+    Args:
+        probs [batch_size, vocab_size]: probs to sample from.
+        logprobs [batch_size, vocab_size]: logprobs (used when
+            save_logprobs is True).
+        sample_indices [n]: Indices of the samples to use for each row of probs.
+        output_samples [n, n_best]: Output tensor to store samples in.
+        output_logprobs [n, n_best]: Output tensor to store logprobs in.
+        output_modified_probs [n, n_best]: Output tensor to store
+            probs of chosen tokens in (modified with noise).
+        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
+            greedy sampling. Note this is ONLY used for determining
+            whether to use random sampling or not. The actual random
+            noise should be passed as uniform_noise.
+        uniform_noise [batch_size, n_best, vocab_size]: Uniform
+            noise to use for random sampling (will be converted
+            to exponential gumbel noise by the kernel).
+        modify_greedy_probs: If True, we modify the probs tensor in-place
+            to encode the sampling method used for each row. This is used
+            in speculative decoding. Only applies in greedy decoding.
+        save_logprobs: If True, we save the logprobs of the sampled tokens
+            in the output_logprobs tensor.
+        save_modified_probs: If True, we save the modified probs (with noise)
+            of the sampled tokens in the output_modified_probs tensor.
+            DOES NOT include the modification done by modify_greedy_probs
+            (because we want to use the unmodified probs to pick the best
+            split in case of multi-split sampling).
+    """
+    n_samples = sample_indices.shape[0]
+    n_cols = probs.shape[1]
+    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1
+
+    # The block size is the smallest power of two greater than the number of
+    # columns in probs
+    block_size = triton.next_power_of_2(n_cols)
+    num_warps = 4
+    # Manual tuning. This seems to give best performance on A100 for
+    # simple kernels like this.
+    if block_size >= 8192:
+        num_warps = 32
+    elif block_size >= 4096:
+        num_warps = 16
+    elif block_size >= 2048:
+        num_warps = 8
+
+    # Enqueue kernel. The 1D launch grid is simple: we have one kernel
+    # instance per row of the probs matrix
+    _sample_triton[(n_samples, n_best)](
+        sample_indices,
+        output_samples,
+        output_logprobs,
+        output_modified_probs,
+        probs,
+        logprobs,
+        seeds,
+        uniform_noise,
+        output_samples.stride(0),
+        probs.stride(0),
+        uniform_noise.stride(0),
+        uniform_noise.stride(1) if n_best > 1 else 1,
+        n_samples,
+        n_cols,
+        n_best,
+        num_warps=num_warps,
+        block_size=block_size,
+        modify_greedy_probs=modify_greedy_probs,
+        save_logprobs=save_logprobs,
+        save_modified_probs=save_modified_probs,
+    )
+    return output_samples, output_logprobs, output_modified_probs
+
+
+@triton.jit
+def _uniform_to_exponential(uniform_noise):
+    """Convert uniform samples to exponential samples."""
+    # tl.rand returns values in [0, 1), so we clamp lower bound
+    # to _EPS to avoid log(0) and thus division by 0 later
+    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
+    uniform_noise = tl.maximum(uniform_noise, lb)
+    # Use the inversion method to turn uniform samples
+    # into exponential samples
+    exponential_noise = -tl.log(uniform_noise)
+    return exponential_noise
+
+
+@triton.jit
+def _sample_triton(
+        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
+        output_logprobs_ptr: torch.Tensor,
+        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
+        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
+        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
+        probs_row_stride: int, uniform_noise_row_stride: int,
+        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
+        n_best: int, block_size: tl.constexpr,
+        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
+        save_modified_probs: tl.constexpr):
+    # The rows are independent, so we parallelize across those
+    sample_idx = tl.program_id(0)
+    best_idx = tl.program_id(1)
+
+    # Load the row index from DRAM
+    row_idx = tl.load(sample_indices_ptr + sample_idx)
+    seed = tl.load(seeds_ptr + sample_idx)
+    uses_random_sampling = seed != 0
+
+    # The stride represents how much we need to increase the
+    # pointer to advance 1 row
+    row_start_ptr = probs_ptr + row_idx * probs_row_stride
+
+    # The block size is the next power of two greater than n_cols,
+    # so we can fit each row in a single block
+    col_offsets = tl.arange(0, block_size)
+
+    # Load the row into SRAM, using a mask since block_size may be > than n_cols
+    row = tl.load(row_start_ptr + col_offsets,
+                  mask=col_offsets < n_cols,
+                  other=float("-inf"))
+
+    if uses_random_sampling:
+        uniform_noise_start_ptr = (uniform_noise_ptr +
+                                   sample_idx * uniform_noise_row_stride +
+                                   best_idx * uniform_noise_best_stride)
+        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
+                                mask=col_offsets < n_cols,
+                                other=0.5)
+        exponential_noise = _uniform_to_exponential(uniform_noise)
+        row /= exponential_noise
+
+    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
+    # Clamp sampled token to n_cols - 1. This should not be necessary,
+    # but we do it just in case.
+    if sampled_token >= n_cols:
+        sampled_token = n_cols - 1
+    # Write back output to DRAM
+    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
+                            best_idx)
+    tl.store(output_row_start_ptr, sampled_token)
+
+    if modify_greedy_probs:  # noqa
+        if not uses_random_sampling:
+            # Set the probability of the sampled token to 1, all other
+            # tokens to zero. This is used in speculative decoding where
+            # the sampling method must be encoded within the sampled
+            # probability distributions.
+            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
+            tl.store(row_start_ptr + col_offsets,
+                     row,
+                     mask=col_offsets < n_cols)
+
+    if save_modified_probs:
+        output_row_start_ptr = (output_modified_probs_ptr +
+                                sample_idx * output_row_stride + best_idx)
+        tl.store(output_row_start_ptr, sampled_value)
+
+    if save_logprobs:
+        # Load the row into SRAM, using a mask since block_size
+        # may be > than n_cols
+        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
+                                  sampled_token)
+        # Write back output to DRAM
+        output_row_start_ptr = (output_logprobs_ptr +
+                                sample_idx * output_row_stride + best_idx)
+        tl.store(output_row_start_ptr, sampled_logprob)
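
To make the sample() entry point concrete, a small hedged example; the batch size, vocab size, and seed values are illustrative assumptions, and it assumes a CUDA device since the kernels are Triton. Per the docstring above, a seed of 0 selects greedy sampling for that row, while a nonzero seed selects seeded gumbel-noise sampling.

import torch

from vllm.model_executor.layers.ops.sample import (
    get_num_triton_sampler_splits, sample)

batch_size, vocab_size = 2, 32000
probs = torch.softmax(
    torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
logprobs = torch.log(probs)

# One seed per sequence and per vocab split (a single split here, since
# vocab_size is far below MAX_TRITON_N_COLS).
n_splits = get_num_triton_sampler_splits(vocab_size)  # == 1
seeds = torch.zeros(batch_size, n_splits, dtype=torch.long, device="cuda")
seeds[1] = 12345  # row 0 stays greedy (seed 0), row 1 samples with noise

tokens, token_logprobs, _ = sample(probs,
                                   seeds,
                                   max_best_of=1,
                                   logprobs=logprobs,
                                   save_logprobs=True)
print(tokens.shape, token_logprobs.shape)  # both [2, 1]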