vllm-npu 0.4.2 (vllm_npu-0.4.2-py3-none-any.whl)

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/attention/backends/torch_sdpa.py
@@ -0,0 +1,250 @@
+ """ Attention layer with torch scaled_dot_product_attention
+ and PagedAttention."""
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import torch
+ from torch.nn.functional import scaled_dot_product_attention
+
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                               AttentionMetadata,
+                                               AttentionMetadataPerStage)
+ from vllm.attention.ops.paged_attn import (PagedAttention,
+                                            PagedAttentionMetadata)
+
+
+ class TorchSDPABackend(AttentionBackend):
+
+     @staticmethod
+     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
+         return TorchSDPABackendImpl
+
+     @staticmethod
+     def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
+         return TorchSDPAMetadata(*args, **kwargs)
+
+     @staticmethod
+     def get_kv_cache_shape(
+         num_blocks: int,
+         block_size: int,
+         num_kv_heads: int,
+         head_size: int,
+     ) -> Tuple[int, ...]:
+         return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                  num_kv_heads, head_size)
+
+     @staticmethod
+     def swap_blocks(
+         src_kv_cache: torch.Tensor,
+         dst_kv_cache: torch.Tensor,
+         src_to_dst: Dict[int, int],
+     ) -> None:
+         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+     @staticmethod
+     def copy_blocks(
+         kv_caches: List[torch.Tensor],
+         src_to_dists: Dict[int, List[int]],
+     ) -> None:
+         PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+ @dataclass
+ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
+                         AttentionMetadataPerStage):
+     """Metadata for TorchSDPABackend.
+     """
+     # Currently, input sequences can only contain all prompts
+     # or all decoding. True if all sequences are prompts.
+     is_prompt: bool
+     slot_mapping: torch.Tensor
+     seq_lens: Optional[List[int]]
+
+     def __post_init__(self):
+         # Set during the execution of the first attention op.
+         # It is a list because it is needed to set per prompt
+         # when alibi slopes is used. It is because of the limitation
+         # from xformer API.
+         # will not appear in the __repr__ and __init__
+         self.attn_bias: Optional[List[torch.Tensor]] = None
+
+
+ class TorchSDPABackendImpl(AttentionImpl):
+
+     def __init__(
+         self,
+         num_heads: int,
+         head_size: int,
+         scale: float,
+         num_kv_heads: Optional[int] = None,
+         alibi_slopes: Optional[List[float]] = None,
+         sliding_window: Optional[int] = None,
+     ) -> None:
+         self.num_heads = num_heads
+         self.head_size = head_size
+         self.scale = float(scale)
+         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+         self.sliding_window = sliding_window
+         if alibi_slopes is not None:
+             assert len(alibi_slopes) == num_heads
+             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+         self.alibi_slopes = alibi_slopes
+         self.need_mask = (self.alibi_slopes is not None
+                           or self.sliding_window is not None)
+
+         assert self.num_heads % self.num_kv_heads == 0
+         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+         suppored_head_sizes = PagedAttention.get_supported_head_sizes()
+         if head_size not in suppored_head_sizes:
+             raise ValueError(
+                 f"Head size {head_size} is not supported by PagedAttention. "
+                 f"Supported head sizes are: {suppored_head_sizes}.")
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         kv_cache: Optional[torch.Tensor],
+         attn_metadata: TorchSDPAMetadata,  # type: ignore
+         kv_scale: float,
+     ) -> torch.Tensor:
+         """Forward pass with torch SDPA and PagedAttention.
+
+         Args:
+             query: shape = [num_tokens, num_heads * head_size]
+             key: shape = [num_tokens, num_kv_heads * head_size]
+             value: shape = [num_tokens, num_kv_heads * head_size]
+             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+             attn_metadata: Metadata for attention.
+         Returns:
+             shape = [num_tokens, num_heads * head_size]
+         """
+         num_tokens, hidden_size = query.shape
+         # Reshape the query, key, and value tensors.
+         query = query.view(-1, self.num_heads, self.head_size)
+         key = key.view(-1, self.num_kv_heads, self.head_size)
+         value = value.view(-1, self.num_kv_heads, self.head_size)
+
+         if kv_cache is not None:
+             key_cache, value_cache = PagedAttention.split_kv_cache(
+                 kv_cache, self.num_kv_heads, self.head_size)
+             PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                 value_cache,
+                                                 attn_metadata.slot_mapping,
+                                                 attn_metadata.kv_cache_dtype,
+                                                 kv_scale)
+
+         if attn_metadata.is_prompt:
+             assert attn_metadata.seq_lens is not None
+             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+                 if self.num_kv_heads != self.num_heads:
+                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+                     value = value.repeat_interleave(self.num_queries_per_kv,
+                                                     dim=1)
+
+                 if attn_metadata.attn_bias is None:
+                     if self.alibi_slopes is not None:
+                         att_masks = _make_alibi_bias(
+                             self.alibi_slopes, query.dtype,
+                             attn_metadata.seq_lens)  # type: ignore
+                     elif self.sliding_window is not None:
+                         att_masks = _make_sliding_window_bias(
+                             attn_metadata.seq_lens, self.sliding_window,
+                             query.dtype)  # type: ignore
+                     else:
+                         att_masks = [None] * len(attn_metadata.seq_lens)
+                     attn_metadata.attn_bias = att_masks
+
+                 query = query.movedim(0, query.dim() - 2)
+                 key = key.movedim(0, key.dim() - 2)
+                 value = value.movedim(0, value.dim() - 2)
+
+                 start = 0
+                 output = torch.empty(
+                     (num_tokens, self.num_heads, self.head_size),
+                     dtype=query.dtype)
+                 for seq_len, mask in zip(attn_metadata.seq_lens,
+                                          attn_metadata.attn_bias):
+                     end = start + seq_len
+                     sub_out = scaled_dot_product_attention(
+                         query[:, start:end, :],
+                         key[:, start:end, :],
+                         value[:, start:end, :],
+                         attn_mask=mask,
+                         dropout_p=0.0,
+                         is_causal=not self.need_mask,
+                         scale=self.scale).movedim(query.dim() - 2, 0)
+                     output[start:end, :, :] = sub_out
+                     start = end
+             else:
+                 # prefix-enabled attention
+                 raise RuntimeError(
+                     "Torch SDPA backend doesn't support prefix decoding.")
+
+         else:
+             # Decoding run.
+             output = PagedAttention.forward_decode(
+                 query,
+                 key_cache,
+                 value_cache,
+                 attn_metadata.block_tables,
+                 attn_metadata.seq_lens_tensor,
+                 attn_metadata.max_seq_len,
+                 attn_metadata.kv_cache_dtype,
+                 self.num_kv_heads,
+                 self.scale,
+                 self.alibi_slopes,
+                 kv_scale,
+             )
+
+         # Reshape the output tensor.
+         return output.view(-1, self.num_heads * self.head_size)
+
+
+ def _make_alibi_bias(
+     alibi_slopes: torch.Tensor,
+     dtype: torch.dtype,
+     seq_lens: List[int],
+ ) -> List[torch.Tensor]:
+     attn_biases = []
+     for seq_len in seq_lens:
+         bias = torch.arange(seq_len, dtype=dtype)
+         # NOTE(zhuohan): HF uses
+         #     `bias = bias[None, :].repeat(seq_len, 1)`
+         # here. We find that both biases give the same results, but
+         # the bias below more accurately follows the original ALiBi
+         # paper.
+         bias = bias[None, :] - bias[:, None]
+
+         num_heads = alibi_slopes.shape[0]
+         bias = bias[None, :].repeat((num_heads, 1, 1))
+         bias.mul_(alibi_slopes[:, None, None])
+         inf_mask = torch.empty(
+             (1, seq_len, seq_len),
+             dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
+         attn_biases.append((bias + inf_mask).to(dtype))
+
+     return attn_biases
+
+
+ def _make_sliding_window_bias(
+     seq_lens: List[int],
+     window_size: Optional[int],
+     dtype: torch.dtype,
+ ) -> List[torch.Tensor]:
+     attn_biases = []
+     for seq_len in seq_lens:
+         tensor = torch.full(
+             (1, seq_len, seq_len),
+             dtype=dtype,
+             fill_value=1,
+         )
+         shift = 0
+         mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
+         if window_size is not None:
+             mask = torch.triu(mask, diagonal=shift - window_size + 1)
+         mask = torch.log(mask)
+         attn_biases.append(mask.to(dtype))
+
+     return attn_biases
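
For illustration only, here is a minimal standalone sketch of the masking trick used by `_make_sliding_window_bias` and the prefill loop above: a 0/1 causal, banded matrix is converted into an additive bias with `torch.log` (1 becomes 0.0, 0 becomes -inf) and passed to `scaled_dot_product_attention` as `attn_mask`, with `is_causal=False` because causality is already baked into the bias. The sizes and tensor names below are invented for the example and are not part of the package.

import torch
from torch.nn.functional import scaled_dot_product_attention

# Illustrative sizes, not taken from the package.
seq_len, window_size, num_heads, head_size = 8, 4, 2, 16

# Causal, sliding-window 0/1 mask turned into an additive bias:
# log(1) = 0.0 keeps a position, log(0) = -inf masks it out.
ones = torch.full((1, seq_len, seq_len), 1.0)
mask = torch.tril(ones)                               # causal part
mask = torch.triu(mask, diagonal=-(window_size - 1))  # keep only the last `window_size` keys
bias = torch.log(mask)

q = torch.randn(num_heads, seq_len, head_size)
k = torch.randn(num_heads, seq_len, head_size)
v = torch.randn(num_heads, seq_len, head_size)

# Mirrors `is_causal=not self.need_mask` above: when a bias is supplied,
# the causal structure lives in the bias, not in the is_causal flag.
out = scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 16])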
vllm/attention/backends/xformers.py
@@ -0,0 +1,393 @@
+ """Attention layer with xFormers and PagedAttention."""
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import torch
+ from xformers import ops as xops
+ from xformers.ops.fmha.attn_bias import (AttentionBias,
+                                          BlockDiagonalCausalMask,
+                                          LowerTriangularMaskWithTensorBias)
+
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                               AttentionMetadata,
+                                               AttentionMetadataPerStage)
+ from vllm.attention.ops.paged_attn import (PagedAttention,
+                                            PagedAttentionMetadata)
+ from vllm.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class XFormersBackend(AttentionBackend):
+
+     @staticmethod
+     def get_impl_cls() -> Type["XFormersImpl"]:
+         return XFormersImpl
+
+     @staticmethod
+     def make_metadata(*args, **kwargs) -> "XFormersMetadata":
+         return XFormersMetadata(*args, **kwargs)
+
+     @staticmethod
+     def get_kv_cache_shape(
+         num_blocks: int,
+         block_size: int,
+         num_kv_heads: int,
+         head_size: int,
+     ) -> Tuple[int, ...]:
+         return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                  num_kv_heads, head_size)
+
+     @staticmethod
+     def swap_blocks(
+         src_kv_cache: torch.Tensor,
+         dst_kv_cache: torch.Tensor,
+         src_to_dst: Dict[int, int],
+     ) -> None:
+         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+     @staticmethod
+     def copy_blocks(
+         kv_caches: List[torch.Tensor],
+         src_to_dists: Dict[int, List[int]],
+     ) -> None:
+         PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+ @dataclass
+ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+     """Metadata for XFormersbackend.
+
+     NOTE: Any python object stored here is not updated when it is
+     cuda-graph replayed. If you have values that need to be changed
+     dynamically, it should be stored in tensor. The tensor has to be
+     updated from `CUDAGraphRunner.forward` API.
+     """
+     # Currently, input sequences can only contain all prompts
+     # or all decoding. True if all sequences are prompts.
+     is_prompt: bool
+     # (batch_size,). The sequence length per sequence. Sequence length means
+     # the computed tokens + new tokens None if it is a decoding.
+     seq_lens: Optional[List[int]]
+     # seq_lens stored as a tensor.
+     seq_lens_tensor: Optional[torch.Tensor]
+
+     # |---------- N-1 iteration --------|
+     # |---------------- N iteration ---------------------|
+     # |- tokenA -|......................|-- newTokens ---|
+     # |---------- context_len ----------|
+     # |-------------------- seq_len ----------------------|
+     #                                   |-- query_len ---|
+
+     # Maximum query length in the batch.
+     max_query_len: Optional[int]
+     # FIXME: It is for flash attn.
+     # Maximum sequence length in the batch.
+     max_seq_len: Optional[int]
+     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+     # the batch, used to index into subquery. E.g., if the subquery length
+     # is [4, 6], it is [0, 4, 10].
+     subquery_start_loc: Optional[torch.Tensor]
+     # FIXME: It is for flash attn.
+     # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+     # the batch, used to index into sequence. E.g., if the sequence length is
+     # [4, 6], it is [0, 4, 10].
+     seq_start_loc: Optional[torch.Tensor]
+     # (batch_size,) A tensor of context lengths (tokens that are computed
+     # so far).
+     context_lens_tensor: Optional[torch.Tensor]
+
+     # Whether or not if cuda graph is enabled.
+     # Cuda-graph is currently enabled for decoding only.
+     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+     use_cuda_graph: bool
+
+     def __post_init__(self):
+         # Set during the execution of the first attention op.
+         # It is a list because it is needed to set per prompt
+         # when alibi slopes is used. It is because of the limitation
+         # from xformer API.
+         # will not appear in the __repr__ and __init__
+         self.attn_bias: Optional[List[AttentionBias]] = None
+
+
+ class XFormersImpl(AttentionImpl):
+     """
+     If the input tensors contain prompt tokens, the layout is as follows:
+     |<--------------- num_prefill_tokens ----------------->|
+     |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
+
+     Otherwise, the layout is as follows:
+     |<----------------- num_decode_tokens ------------------>|
+     |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
+
+     Generation tokens can contain padding when cuda-graph is used.
+     Currently, prompt tokens don't contain any padding.
+
+     The prompts might have different lengths, while the generation tokens
+     always have length 1.
+
+     If chunked prefill is enabled, prefill tokens and decode tokens can be
+     batched together in a flattened 1D query.
+
+     |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
+     |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
+
+     Currently, cuda graph is disabled for chunked prefill, meaning there's no
+     padding between prefill and decode tokens.
+     """
+
+     def __init__(
+         self,
+         num_heads: int,
+         head_size: int,
+         scale: float,
+         num_kv_heads: Optional[int] = None,
+         alibi_slopes: Optional[List[float]] = None,
+         sliding_window: Optional[int] = None,
+     ) -> None:
+         self.num_heads = num_heads
+         self.head_size = head_size
+         self.scale = float(scale)
+         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+         self.sliding_window = sliding_window
+         if alibi_slopes is not None:
+             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+         self.alibi_slopes = alibi_slopes
+
+         assert self.num_heads % self.num_kv_heads == 0
+         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+         suppored_head_sizes = PagedAttention.get_supported_head_sizes()
+         if head_size not in suppored_head_sizes:
+             raise ValueError(
+                 f"Head size {head_size} is not supported by PagedAttention. "
+                 f"Supported head sizes are: {suppored_head_sizes}.")
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         kv_cache: Optional[torch.Tensor],
+         attn_metadata: AttentionMetadata[XFormersMetadata],
+         kv_scale: float,
+     ) -> torch.Tensor:
+         """Forward pass with xFormers and PagedAttention.
+
+         Args:
+             query: shape = [num_tokens, num_heads * head_size]
+             key: shape = [num_tokens, num_kv_heads * head_size]
+             value: shape = [num_tokens, num_kv_heads * head_size]
+             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+             attn_metadata: Metadata for attention.
+         Returns:
+             shape = [num_tokens, num_heads * head_size]
+         """
+         num_tokens, hidden_size = query.shape
+         query = query.view(-1, self.num_heads, self.head_size)
+         key = key.view(-1, self.num_kv_heads, self.head_size)
+         value = value.view(-1, self.num_kv_heads, self.head_size)
+
+         if kv_cache is not None:
+             key_cache, value_cache = PagedAttention.split_kv_cache(
+                 kv_cache, self.num_kv_heads, self.head_size)
+
+             # Reshape the input keys and values and store them in the cache.
+             # If kv_cache is not provided, the new key and value tensors are
+             # not cached. This happens during the initial memory profiling run.
+             PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                 value_cache,
+                                                 attn_metadata.slot_mapping,
+                                                 attn_metadata.kv_cache_dtype,
+                                                 kv_scale)
+
+         num_prefill_tokens = attn_metadata.num_prefill_tokens
+         num_decode_tokens = attn_metadata.num_decode_tokens
+         assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+         assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+         output = torch.empty_like(query)
+         # Query for decode. KV is not needed because it is already cached.
+         decode_query = query[num_prefill_tokens:]
+         # QKV for prefill.
+         query = query[:num_prefill_tokens]
+         key = key[:num_prefill_tokens]
+         value = value[:num_prefill_tokens]
+
+         assert query.shape[0] == num_prefill_tokens
+         assert decode_query.shape[0] == num_decode_tokens
+
+         if prefill_meta := attn_metadata.prefill_metadata:
+             # Prompt run.
+             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                 # normal attention.
+                 # block tables are empty if the prompt does not have a cached
+                 # prefix.
+                 out = self._run_memory_efficient_xformers_forward(
+                     query, key, value, prefill_meta)
+                 assert out.shape == output[:num_prefill_tokens].shape
+                 output[:num_prefill_tokens] = out
+             else:
+                 # prefix-enabled attention
+                 # TODO(Hai) this triton kernel has regression issue (broke) to
+                 # deal with different data types between KV and FP8 KV cache,
+                 # to be addressed separately.
+                 out = PagedAttention.forward_prefix(
+                     query,
+                     key,
+                     value,
+                     key_cache,
+                     value_cache,
+                     prefill_meta.block_tables,
+                     prefill_meta.subquery_start_loc,
+                     prefill_meta.seq_lens_tensor,
+                     prefill_meta.context_lens_tensor,
+                     prefill_meta.max_query_len,
+                     self.alibi_slopes,
+                     self.sliding_window,
+                 )
+                 assert output[:num_prefill_tokens].shape == out.shape
+                 output[:num_prefill_tokens] = out
+
+         if decode_meta := attn_metadata.decode_metadata:
+             output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                 decode_query,
+                 key_cache,
+                 value_cache,
+                 decode_meta.block_tables,
+                 decode_meta.seq_lens_tensor,
+                 decode_meta.max_seq_len,
+                 attn_metadata.kv_cache_dtype,
+                 self.num_kv_heads,
+                 self.scale,
+                 self.alibi_slopes,
+                 kv_scale,
+             )
+
+         # Reshape the output tensor.
+         return output.view(-1, self.num_heads * self.head_size)
+
+     def _run_memory_efficient_xformers_forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attn_metadata: XFormersMetadata,
+     ) -> torch.Tensor:
+         """Attention for 1D query of multiple prompts. Multiple prompt
+         tokens are flattened in to `query` input.
+
+         See https://facebookresearch.github.io/xformers/components/ops.html
+         for API spec.
+
+         Args:
+             output: shape = [num_prefill_tokens, num_heads, head_size]
+             query: shape = [num_prefill_tokens, num_heads, head_size]
+             key: shape = [num_prefill_tokens, num_kv_heads, head_size]
+             value: shape = [num_prefill_tokens, num_kv_heads, head_size]
+             attn_metadata: Metadata for attention.
+         """
+         assert attn_metadata.seq_lens is not None
+         original_query = query
+         if self.num_kv_heads != self.num_heads:
+             # GQA/MQA requires the shape [B, M, G, H, K].
+             # Note that the output also has the same shape (which is different
+             # from a spec from the doc).
+             query = query.view(query.shape[0], self.num_kv_heads,
+                                self.num_queries_per_kv, query.shape[-1])
+             key = key[:, :,
+                       None, :].expand(key.shape[0], self.num_kv_heads,
+                                       self.num_queries_per_kv, key.shape[-1])
+             value = value[:, :,
+                           None, :].expand(value.shape[0], self.num_kv_heads,
+                                           self.num_queries_per_kv,
+                                           value.shape[-1])
+         # Set attention bias if not provided. This typically happens at
+         # the very attention layer of every iteration.
+         # FIXME(woosuk): This is a hack.
+         if attn_metadata.attn_bias is None:
+             if self.alibi_slopes is None:
+                 attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                     attn_metadata.seq_lens)
+                 if self.sliding_window is not None:
+                     attn_bias = attn_bias.make_local_attention(
+                         self.sliding_window)
+                 attn_metadata.attn_bias = [attn_bias]
+             else:
+                 attn_metadata.attn_bias = _make_alibi_bias(
+                     self.alibi_slopes, self.num_kv_heads, query.dtype,
+                     attn_metadata.seq_lens)
+
+         # No alibi slopes.
+         # TODO(woosuk): Too many view operations. Let's try to reduce
+         # them in the future for code readability.
+         if self.alibi_slopes is None:
+             # Add the batch dimension.
+             query = query.unsqueeze(0)
+             key = key.unsqueeze(0)
+             value = value.unsqueeze(0)
+             out = xops.memory_efficient_attention_forward(
+                 query,
+                 key,
+                 value,
+                 attn_bias=attn_metadata.attn_bias[0],
+                 p=0.0,
+                 scale=self.scale)
+             return out.view_as(original_query)
+
+         # Attention with alibi slopes.
+         # FIXME(woosuk): Because xformers does not support dynamic sequence
+         # lengths with custom attention bias, we process each prompt one by
+         # one. This is inefficient, especially when we have many short prompts.
+         output = torch.empty_like(original_query)
+         start = 0
+         for i, seq_len in enumerate(attn_metadata.seq_lens):
+             end = start + seq_len
+             out = xops.memory_efficient_attention_forward(
+                 query[None, start:end],
+                 key[None, start:end],
+                 value[None, start:end],
+                 attn_bias=attn_metadata.attn_bias[i],
+                 p=0.0,
+                 scale=self.scale)
+             # TODO(woosuk): Unnecessary copy. Optimize.
+             output[start:end].copy_(out.view_as(original_query[start:end]))
+             start += seq_len
+         return output
+
+
+ def _make_alibi_bias(
+     alibi_slopes: torch.Tensor,
+     num_kv_heads: int,
+     dtype: torch.dtype,
+     seq_lens: List[int],
+ ) -> LowerTriangularMaskWithTensorBias:
+     attn_biases = []
+     for seq_len in seq_lens:
+         bias = torch.arange(seq_len, dtype=dtype)
+         # NOTE(zhuohan): HF uses
+         #     `bias = bias[None, :].repeat(seq_len, 1)`
+         # here. We find that both biases give the same results, but
+         # the bias below more accurately follows the original ALiBi
+         # paper.
+         # Calculate a matrix where each element represents ith element- jth
+         # element.
+         bias = bias[None, :] - bias[:, None]
+
+         padded_len = (seq_len + 7) // 8 * 8
+         num_heads = alibi_slopes.shape[0]
+         bias = torch.empty(
+             1,  # batch size
+             num_heads,
+             seq_len,
+             padded_len,
+             device=alibi_slopes.device,
+             dtype=dtype,
+         )[:, :, :, :seq_len].copy_(bias)
+         bias.mul_(alibi_slopes[:, None, None])
+         if num_heads != num_kv_heads:
+             bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
+         attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
+
+     return attn_biases
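
As a side note, the GQA/MQA reshaping that `_run_memory_efficient_xformers_forward` performs before calling xFormers can be sketched with plain torch: query heads are folded into [kv_heads, queries_per_kv] groups and key/value heads are broadcast with `expand` (no copy), giving the 5-D [B, M, G, H, K] layout once a batch dimension is added. The head counts below are invented for the example.

import torch

# Illustrative head configuration, not taken from the package.
num_tokens, num_heads, num_kv_heads, head_size = 6, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads  # queries sharing one KV head

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Fold query heads into (kv_head, query-within-group) pairs.
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# Broadcast each KV head across its query group without materializing copies.
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)

# After unsqueeze(0) this is the [B, M, G, H, K] layout
# (batch, tokens, KV-head groups, queries per group, head dim)
# expected by xops.memory_efficient_attention_forward for GQA.
print(query.unsqueeze(0).shape)  # torch.Size([1, 6, 2, 4, 64])
print(key.unsqueeze(0).shape)    # torch.Size([1, 6, 2, 4, 64])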
vllm/attention/layer.py
@@ -0,0 +1,56 @@
+ """Attention layer."""
+ from typing import List, Optional
+
+ import torch
+ import torch.nn as nn
+
+ from vllm.attention.backends.abstract import (AttentionMetadata,
+                                               AttentionMetadataPerStage)
+ from vllm.attention.selector import get_attn_backend
+
+
+ class Attention(nn.Module):
+     """Attention layer.
+
+     This class takes query, key, and value tensors as input. The input tensors
+     can either contain prompt tokens or generation tokens.
+     The class does the following:
+
+     1. Store the input key and value tensors in the KV cache.
+     2. Perform (multi-head/multi-query/grouped-query) attention.
+     3. Return the output tensor.
+     """
+
+     def __init__(
+         self,
+         num_heads: int,
+         head_size: int,
+         scale: float,
+         num_kv_heads: Optional[int] = None,
+         alibi_slopes: Optional[List[float]] = None,
+         sliding_window: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+         self.backend = get_attn_backend(torch.get_default_dtype())
+         impl_cls = self.backend.get_impl_cls()
+         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
+                              alibi_slopes, sliding_window)
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         kv_cache: Optional[torch.Tensor],
+         attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+         kv_scale: float = 1.0,
+     ) -> torch.Tensor:
+         return self.impl.forward(query, key, value, kv_cache, attn_metadata,
+                                  kv_scale)
+
+     def extra_repr(self) -> str:
+         s = f"head_size={self.impl.head_size}"  # type: ignore
+         s += f", num_heads={self.impl.num_heads}"  # type: ignore
+         s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
+         s += f", scale={self.impl.scale}"  # type: ignore
+         return s
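
A hedged usage sketch for the `Attention` module above, assuming this wheel is installed and that `get_attn_backend` can resolve a working backend for the current device and default dtype; the head configuration is invented. A real forward pass additionally needs a KV-cache tensor and an `AttentionMetadata` object, which the worker's model runner builds, so only construction is shown here.

from vllm.attention.layer import Attention  # module shown in the hunk above

# Illustrative configuration, not taken from any particular model.
num_heads, num_kv_heads, head_size = 32, 8, 128

# The constructor selects a backend for the default dtype and instantiates
# its implementation class with the same arguments the backends above accept.
attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size**-0.5,
    num_kv_heads=num_kv_heads,
)

# extra_repr() surfaces the resolved configuration, e.g.
# Attention(head_size=128, num_heads=32, num_kv_heads=8, scale=0.088...)
print(attn)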