vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/attention/backends/flashinfer.py
@@ -0,0 +1,220 @@
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Set, Tuple, Type
+
+ try:
+     import flashinfer
+     from flash_attn import flash_attn_varlen_func
+     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+ except ImportError:
+     flashinfer = None
+     flash_attn_varlen_func = None
+     BatchDecodeWithPagedKVCacheWrapper = None
+
+ import torch
+
+ from vllm import _custom_ops as ops
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                               AttentionMetadata,
+                                               AttentionMetadataPerStage)
+
+
+ class FlashInferBackend(AttentionBackend):
+
+     @staticmethod
+     def get_impl_cls() -> Type["FlashInferImpl"]:
+         return FlashInferImpl
+
+     @staticmethod
+     def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
+         return FlashInferMetadata(*args, **kwargs)
+
+     @staticmethod
+     def get_kv_cache_shape(
+         num_blocks: int,
+         block_size: int,
+         num_kv_heads: int,
+         head_size: int,
+     ) -> Tuple[int, ...]:
+         return (num_blocks, 2, block_size, num_kv_heads, head_size)
+
+     @staticmethod
+     def swap_blocks(
+         src_kv_cache: torch.Tensor,
+         dst_kv_cache: torch.Tensor,
+         src_to_dst: Dict[int, int],
+     ) -> None:
+         raise NotImplementedError
+
+     @staticmethod
+     def copy_blocks(
+         kv_caches: List[torch.Tensor],
+         src_to_dists: Dict[int, List[int]],
+     ) -> None:
+         raise NotImplementedError
+
+     @staticmethod
+     def get_supported_head_sizes() -> List[int]:
+         return [64, 128, 256]
+
+
+ @dataclass
+ class FlashInferMetadata(AttentionMetadataPerStage):
+
+     is_prompt: bool
+
+     use_cuda_graph: bool = False
+
+     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
+
+     # Metadata for the prefill stage since we still
+     # use flash attention for prefill.
+     seq_start_loc: Optional[torch.Tensor] = None
+     max_seq_len: Optional[int] = None
+     block_tables: Optional[torch.Tensor] = None
+
+     # Metadata for the decode stage
+     # Workspace buffer required by the kernel, the buffer should not
+     # be allocated/deallocated by the FlashInferMetadata object.
+     workspace_buffer: Optional[torch.Tensor] = None
+     # An example for paged_kv_indices, paged_kv_indptr:
+     # request 1, page indices [0, 5, 8]
+     # request 2, page indices [1, 6, 7]
+     # request 3, page indices [3, 4]
+     # paged_kv_indices is a concatenation of page indices of all requests:
+     # [0, 5, 8, 1, 6, 7, 3, 4]
+     # paged_kv_indptr is used to index into paged_kv_indices:
+     # [0, 3, 6, 8]
+     # The indptr of the paged kv cache, shape: [batch_size + 1]
+     paged_kv_indptr: Optional[torch.Tensor] = None
+     # The page indices of the paged kv cache
+     paged_kv_indices: Optional[torch.Tensor] = None
+     # The number of entries in the last page of each request in
+     # the paged kv cache, shape: [batch_size]
+     paged_kv_last_page_len: Optional[torch.Tensor] = None
+     # The number of query/output heads
+     num_qo_heads: Optional[int] = None
+     # The number of key/value heads
+     num_kv_heads: Optional[int] = None
+     # The dimension of the attention heads
+     head_dim: Optional[int] = None
+     # Block size of vllm
+     page_size: Optional[int] = None
+     # The data type of the paged kv cache
+     data_type: torch.dtype = None
+
+     def __post_init__(self):
+         # Refer to
+         # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
+         supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
+         if self.head_dim is not None and self.head_dim \
+                 not in supported_head_sizes:
+             raise ValueError(
+ f"Only {supported_head_sizes} are supported for head_dim,",
113
+ f"received {self.head_dim}.")
+
+         # When using flashinfer, we are also creating the FlashInferMetadata,
+         # which will also call post_init by default, here we want to skip the
+         # post_init if it's the prefill phase.
+         if not self.is_prompt:
+             self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                 self.workspace_buffer, "NHD")
+             self.decode_wrapper.begin_forward(
+                 self.paged_kv_indptr,
+                 self.paged_kv_indices,
+                 self.paged_kv_last_page_len,
+                 self.num_qo_heads,
+                 self.num_kv_heads,
+                 self.head_dim,
+                 self.page_size,
+                 # Disable flashinfer's pos encoding and use vllm's rope.
+                 pos_encoding_mode="NONE",
+                 data_type=self.data_type)
+
+     def asdict_zerocopy(self,
+                         skip_fields: Optional[Set[str]] = None
+                         ) -> Dict[str, Any]:
+         if skip_fields is None:
+             skip_fields = set()
+         # We need to skip the decode_wrapper field since it cannot be
+         # broadcasted with nccl when TP is enabled.
+         skip_fields.add('decode_wrapper')
+         return super().asdict_zerocopy(skip_fields)
+
+
+ class FlashInferImpl(AttentionImpl):
+
+     def __init__(
+         self,
+         num_heads: int,
+         head_size: int,
+         scale: float,
+         num_kv_heads: Optional[int] = None,
+         alibi_slopes: Optional[List[float]] = None,
+         sliding_window: Optional[int] = None,
+     ) -> None:
+         if sliding_window is not None:
+             raise ValueError("Sliding window is not supported in FlashInfer.")
+         self.sliding_window = (-1, -1)
+         self.alibi_slopes = alibi_slopes
+         self.scale = scale
+         self.num_heads = num_heads
+         self.head_size = head_size
+         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+     def forward(self, query: torch.Tensor, key: torch.Tensor,
+                 value: torch.Tensor, kv_cache: Optional[torch.Tensor],
+                 attn_metadata: AttentionMetadata[FlashInferMetadata],
+                 kv_scale: float):
+         num_tokens, hidden_size = query.shape
+         query = query.view(-1, self.num_heads, self.head_size)
+         key = key.view(-1, self.num_kv_heads, self.head_size)
+         value = value.view(-1, self.num_kv_heads, self.head_size)
+
+         if attn_metadata.num_prefill_tokens > 0:
+             assert attn_metadata.num_decode_tokens == 0, (
+                 "Chunked prefill is not supported with flashinfer yet.")
+         if attn_metadata.num_decode_tokens > 0:
+             assert attn_metadata.num_prefill_tokens == 0, (
+                 "Chunked prefill is not supported with flashinfer yet.")
+
+         if kv_cache is not None:
+             # Use the same reshape and cache kernel as flash attention.
+             ops.reshape_and_cache_flash(
+                 key,
+                 value,
+                 kv_cache[:, 0],
+                 kv_cache[:, 1],
+                 attn_metadata.slot_mapping.flatten(),
+                 attn_metadata.kv_cache_dtype,
+             )
+
+         if prefill_meta := attn_metadata.prefill_metadata:
+             assert prefill_meta.block_tables is not None
+             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                 output = flash_attn_varlen_func(
+                     q=query,
+                     k=key,
+                     v=value,
+                     cu_seqlens_q=prefill_meta.seq_start_loc,
+                     cu_seqlens_k=prefill_meta.seq_start_loc,
+                     max_seqlen_q=prefill_meta.max_seq_len,
+                     max_seqlen_k=prefill_meta.max_seq_len,
+                     softmax_scale=self.scale,
+                     causal=True,
+                     window_size=self.sliding_window,
+                     alibi_slopes=self.alibi_slopes,
+                 )
+             else:
+                 raise NotImplementedError(
+                     "Prefix caching is not supported with flashinfer yet.")
+         else:
+             assert attn_metadata.decode_metadata is not None
+             assert attn_metadata.decode_metadata.decode_wrapper is not None
+             query = query.contiguous(
+             )  # Flashinfer requires query to be contiguous
+             output = attn_metadata.decode_metadata.decode_wrapper.forward(
+                 query,
+                 kv_cache,
+                 sm_scale=self.scale,
+             )
+         return output.view(num_tokens, hidden_size)
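
Note on the decode-stage metadata above: `paged_kv_indices`, `paged_kv_indptr` and `paged_kv_last_page_len` form a CSR-style description of each request's KV-cache pages, exactly as the inline comments illustrate. The sketch below is not part of this wheel; it only shows, with a hypothetical helper `build_paged_kv_metadata`, how those three tensors could be derived from per-request block tables and sequence lengths (the real model runner builds them internally and places them on the device).

```python
# Illustrative sketch (not from this wheel): building the CSR-style paged-KV
# metadata that FlashInferMetadata expects from per-request block tables and
# sequence lengths. `build_paged_kv_metadata` is a hypothetical helper.
from typing import List, Tuple

import torch


def build_paged_kv_metadata(
        block_tables: List[List[int]], seq_lens: List[int],
        page_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return (paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len)."""
    indices: List[int] = []
    indptr: List[int] = [0]
    last_page_len: List[int] = []
    for pages, seq_len in zip(block_tables, seq_lens):
        # Number of pages actually holding tokens of this request.
        num_pages = (seq_len + page_size - 1) // page_size
        indices.extend(pages[:num_pages])
        # Running offset into `indices`; shape [batch_size + 1] overall.
        indptr.append(indptr[-1] + num_pages)
        # Entries used in the request's last page (1..page_size).
        last_page_len.append(seq_len - (num_pages - 1) * page_size)
    return (torch.tensor(indices, dtype=torch.int32),
            torch.tensor(indptr, dtype=torch.int32),
            torch.tensor(last_page_len, dtype=torch.int32))


# Matches the example in the comments above: requests with page indices
# [0, 5, 8], [1, 6, 7] and [3, 4].
indices, indptr, last_len = build_paged_kv_metadata(
    block_tables=[[0, 5, 8], [1, 6, 7], [3, 4]],
    seq_lens=[40, 35, 20],
    page_size=16)
assert indices.tolist() == [0, 5, 8, 1, 6, 7, 3, 4]
assert indptr.tolist() == [0, 3, 6, 8]
assert last_len.tolist() == [8, 3, 4]
```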
vllm/attention/backends/rocm_flash_attn.py
@@ -0,0 +1,374 @@
+ """Attention layer for ROCm GPUs."""
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import torch
+
+ import vllm.envs as envs
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                               AttentionMetadata,
+                                               AttentionMetadataPerStage)
+ from vllm.attention.ops.paged_attn import (PagedAttention,
+                                            PagedAttentionMetadata)
+ from vllm.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ROCmFlashAttentionBackend(AttentionBackend):
+
+     @staticmethod
+     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
+         return ROCmFlashAttentionImpl
+
+     @staticmethod
+     def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
+         return ROCmFlashAttentionMetadata(*args, **kwargs)
+
+     @staticmethod
+     def get_kv_cache_shape(
+         num_blocks: int,
+         block_size: int,
+         num_kv_heads: int,
+         head_size: int,
+     ) -> Tuple[int, ...]:
+         return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                  num_kv_heads, head_size)
+
+     @staticmethod
+     def swap_blocks(
+         src_kv_cache: torch.Tensor,
+         dst_kv_cache: torch.Tensor,
+         src_to_dst: Dict[int, int],
+     ) -> None:
+         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+     @staticmethod
+     def copy_blocks(
+         kv_caches: List[torch.Tensor],
+         src_to_dists: Dict[int, List[int]],
+     ) -> None:
+         PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+ @dataclass
+ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
+                                  PagedAttentionMetadata):
+     """Metadata for FlashAttentionBackend.
+
+     NOTE: Any python object stored here is not updated when it is
+     cuda-graph replayed. If you have values that need to be changed
+     dynamically, it should be stored in tensor. The tensor has to be
+     updated from `CUDAGraphRunner.forward` API.
+     """
+     # Currently, input sequences can only contain all prompts
+     # or all decoding. True if all sequences are prompts.
+     is_prompt: bool
+     # (batch_size,). The sequence length per sequence. Sequence length means
+     # the computed tokens + new tokens None if it is a decoding.
+     seq_lens: Optional[List[int]]
+     # seq_lens stored as a tensor.
+     seq_lens_tensor: Optional[torch.Tensor]
+
+     # NOTE(sang): Definition of context_len, query_len, and seq_len.
+     # |---------- N-1 iteration --------|
+     # |---------------- N iteration ---------------------|
+     # |- tokenA -|......................|-- newTokens ---|
+     # |---------- context_len ----------|
+     # |-------------------- seq_len ----------------------|
+     #                                   |-- query_len ---|
+
+     # Maximum query length in the batch.
+     max_query_len: Optional[int]
+     # Maximum sequence length in the batch.
+     max_seq_len: Optional[int]
+     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+     # the batch, used to index into subquery. E.g., if the subquery length
+     # is [4, 6], it is [0, 4, 10].
+     subquery_start_loc: Optional[torch.Tensor]
+     # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+     # the batch, used to index into sequence. E.g., if the sequence length is
+     # [4, 6], it is [0, 4, 10].
+     seq_start_loc: Optional[torch.Tensor]
+
+     # Whether or not cuda graph is enabled.
+     # Cuda-graph is currently enabled for decoding only.
+     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+     use_cuda_graph: bool
+     # (batch_size,) A tensor of context lengths (tokens that are computed
+     # so far).
+     context_lens_tensor: Optional[torch.Tensor]
+
+
+ class ROCmFlashAttentionImpl(AttentionImpl):
+     """
+     If the input tensors contain prompt tokens, the layout is as follows:
+     |<--------------- num_prompt_tokens -------------->|
+     |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+
+     Otherwise, the layout is as follows:
+     |<------------------ num_generation_tokens (M) ----------------->|
+     |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+
+     Generation tokens can contain padding when cuda-graph is used.
+     Currently, prompt tokens don't contain any padding.
+
+     The prompts might have different lengths, while the generation tokens
+     always have length 1.
+
+     If chunked prefill is enabled, prefill tokens and decode tokens can be
+     batched together in a flattened 1D query.
+
+     |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
+     |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
+
+     Currently, cuda graph is disabled for chunked prefill, meaning there's no
+     padding between prefill and decode tokens.
+     """
+
+     def __init__(
+         self,
+         num_heads: int,
+         head_size: int,
+         scale: float,
+         num_kv_heads: Optional[int] = None,
+         alibi_slopes: Optional[List[float]] = None,
+         sliding_window: Optional[int] = None,
+     ) -> None:
+         self.num_heads = num_heads
+         self.head_size = head_size
+         self.scale = float(scale)
+         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+         self.sliding_window = ((sliding_window, sliding_window)
+                                if sliding_window is not None else (-1, -1))
+         if alibi_slopes is not None:
+             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+         self.alibi_slopes = alibi_slopes
+
+         assert self.num_heads % self.num_kv_heads == 0
+         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+         supported_head_sizes = PagedAttention.get_supported_head_sizes()
+         if head_size not in supported_head_sizes:
+             raise ValueError(
+                 f"Head size {head_size} is not supported by PagedAttention. "
+                 f"Supported head sizes are: {supported_head_sizes}.")
+
+         self.use_naive_attn = False
+         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
+         self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
+         if self.use_triton_flash_attn:
+             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
+                 triton_attention)
+             self.attn_func = triton_attention
+             logger.debug("Using Triton FA in ROCmBackend")
+         else:
+             # If not using triton, navi3x does not use flash-attn either.
+             if torch.cuda.get_device_capability()[0] == 11:
+                 self.use_naive_attn = True
+             else:
+                 try:
+                     from flash_attn import flash_attn_varlen_func  # noqa: F401
+                     self.attn_func = flash_attn_varlen_func
+                     logger.debug("Using CK FA in ROCmBackend")
+                 except ModuleNotFoundError:
+                     self.use_naive_attn = True
+
+             if self.use_naive_attn:
+                 self.attn_func = _naive_attention
+                 logger.debug("Using naive attention in ROCmBackend")
+
+     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
+         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
+         tokens, n_kv_heads, head_dim = x.shape
+         return (x[:, :,
+                   None, :].expand(tokens, n_kv_heads, n_rep,
+                                   head_dim).reshape(tokens, n_kv_heads * n_rep,
+                                                     head_dim))
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         kv_cache: torch.Tensor,
+         attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
+         kv_scale: float = 1.0,
+     ) -> torch.Tensor:
+         """Forward pass with FlashAttention and PagedAttention.
+
+         Args:
+             query: shape = [num_tokens, num_heads * head_size]
+             key: shape = [num_tokens, num_kv_heads * head_size]
+             value: shape = [num_tokens, num_kv_heads * head_size]
+             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+             attn_metadata: Metadata for attention.
+         Returns:
+             shape = [num_tokens, num_heads * head_size]
+         """
+         num_tokens, hidden_size = query.shape
+         # Reshape the query, key, and value tensors.
+         query = query.view(-1, self.num_heads, self.head_size)
+         key = key.view(-1, self.num_kv_heads, self.head_size)
+         value = value.view(-1, self.num_kv_heads, self.head_size)
+
+         if kv_cache is not None:
+             key_cache, value_cache = PagedAttention.split_kv_cache(
+                 kv_cache, self.num_kv_heads, self.head_size)
+
+             # Reshape the input keys and values and store them in the cache.
+             # If kv_cache is not provided, the new key and value tensors are
+             # not cached. This happens during the initial memory profiling run.
+             PagedAttention.write_to_paged_cache(
+                 key,
+                 value,
+                 key_cache,
+                 value_cache,
+                 attn_metadata.slot_mapping,
+                 attn_metadata.kv_cache_dtype,
+                 kv_scale,
+             )
+
+         num_prefill_tokens = attn_metadata.num_prefill_tokens
+         num_decode_tokens = attn_metadata.num_decode_tokens
+         assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+         assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+         output = torch.empty_like(query)
+         # Query for decode. KV is not needed because it is already cached.
+         decode_query = query[num_prefill_tokens:]
+         # QKV for prefill.
+         query = query[:num_prefill_tokens]
+         key = key[:num_prefill_tokens]
+         value = value[:num_prefill_tokens]
+
+         assert query.shape[0] == num_prefill_tokens
+         assert decode_query.shape[0] == num_decode_tokens
+
+         if prefill_meta := attn_metadata.prefill_metadata:
+             # Prompt run.
+             assert prefill_meta.seq_lens is not None
+             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                 # triton attention
+                 # When block_tables are not filled, it means q and k are the
+                 # prompt, and they have the same length.
+                 if self.use_triton_flash_attn:
+                     out, _ = self.attn_func(
+                         query,
+                         key,
+                         value,
+                         None,
+                         prefill_meta.seq_start_loc,
+                         prefill_meta.seq_start_loc,
+                         prefill_meta.max_seq_len,
+                         prefill_meta.max_seq_len,
+                         True,
+                         self.scale,
+                     )
+                 elif self.use_naive_attn:
+                     if self.num_kv_heads != self.num_heads:
+                         # Interleave for MQA workaround.
+                         key = self.repeat_kv(key, self.num_queries_per_kv)
+                         value = self.repeat_kv(value, self.num_queries_per_kv)
+                     out = self.attn_func(
+                         query,
+                         key,
+                         value,
+                         prefill_meta.seq_lens,
+                         self.scale,
+                     )
+                 else:
+                     out = self.attn_func(
+                         q=query,
+                         k=key,
+                         v=value,
+                         cu_seqlens_q=prefill_meta.seq_start_loc,
+                         cu_seqlens_k=prefill_meta.seq_start_loc,
+                         max_seqlen_q=prefill_meta.max_seq_len,
+                         max_seqlen_k=prefill_meta.max_seq_len,
+                         softmax_scale=self.scale,
+                         causal=True,
+                     )
+
+                 # common code for prefill
+                 assert output[:num_prefill_tokens].shape == out.shape
+                 output[:num_prefill_tokens] = out
+             else:
+                 # prefix-enabled attention
+                 output[:num_prefill_tokens] = PagedAttention.forward_prefix(
+                     query,
+                     key,
+                     value,
+                     key_cache,
+                     value_cache,
+                     prefill_meta.block_tables,
+                     prefill_meta.subquery_start_loc,
+                     prefill_meta.seq_lens_tensor,
+                     prefill_meta.context_lens_tensor,
+                     prefill_meta.max_query_len,
+                     self.alibi_slopes,
+                     self.sliding_window[0],
+                 )
+
+         if decode_meta := attn_metadata.decode_metadata:
+             # Decoding run.
+             output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                 decode_query,
+                 key_cache,
+                 value_cache,
+                 decode_meta.block_tables,
+                 decode_meta.seq_lens_tensor,
+                 decode_meta.max_seq_len,
+                 attn_metadata.kv_cache_dtype,
+                 self.num_kv_heads,
+                 self.scale,
+                 self.alibi_slopes,
+                 kv_scale,
+             )
+
+         # Reshape the output tensor.
+         return output.view(num_tokens, hidden_size)
+
+
+ def _naive_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     seq_lens: List[int],
+     scale: float,
+ ) -> torch.Tensor:
+     output = torch.empty_like(query)
+     start = 0
+     for _, seq_len in enumerate(seq_lens):
+         end = start + seq_len
+         out = _naive_masked_attention(
+             query[start:end],
+             key[start:end],
+             value[start:end],
+             scale,
+         )
+         # TODO(woosuk): Unnecessary copy. Optimize.
+         output[start:end].copy_(out)
+         start += seq_len
+
+     return output
+
+
+ def _naive_masked_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     scale: float,
+ ) -> torch.Tensor:
+     seq_len, head_size, head_dim = query.shape
+     attn_mask = torch.triu(torch.ones(seq_len,
+                                       seq_len,
+                                       dtype=query.dtype,
+                                       device=query.device),
+                            diagonal=1)
+     attn_mask = attn_mask * torch.finfo(query.dtype).min
+     attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
+     attn_weights = attn_weights + attn_mask.float()
+     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+     out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+     return out
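
The `_naive_masked_attention` fallback above builds an explicit upper-triangular mask and computes causal attention with `einsum`, while `repeat_kv` expands KV heads for the MQA/GQA workaround. The sketch below is not part of this wheel; it is a small CPU-only sanity check, under assumed shapes and tolerances, that this construction agrees with `torch.nn.functional.scaled_dot_product_attention` in causal mode and that the `repeat_kv` expansion matches `torch.repeat_interleave` along the head dimension.

```python
# Illustrative check (not from this wheel): naive masked attention vs. PyTorch
# SDPA, and repeat_kv-style expansion vs. torch.repeat_interleave.
import torch
import torch.nn.functional as F


def naive_masked_attention(query, key, value, scale):
    # Same construction as _naive_masked_attention in the diff above,
    # for inputs of shape (seq_len, num_heads, head_dim).
    seq_len = query.shape[0]
    attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn_weights, value)


seq_len, num_heads, head_dim = 8, 4, 64
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)
scale = head_dim ** -0.5

out_naive = naive_masked_attention(q, k, v, scale)
# scaled_dot_product_attention expects (..., heads, seq, dim).
out_ref = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
    is_causal=True).transpose(0, 1)
assert torch.allclose(out_naive, out_ref, atol=1e-4)

# repeat_kv-style interleaving equals torch.repeat_interleave(x, n_rep, dim=1).
x = torch.randn(seq_len, 2, head_dim)  # 2 KV heads
n_rep = 2
expanded = (x[:, :, None, :].expand(seq_len, 2, n_rep, head_dim)
            .reshape(seq_len, 2 * n_rep, head_dim))
assert torch.equal(expanded, torch.repeat_interleave(x, n_rep, dim=1))
```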