vllm-npu 0.4.2__py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/__init__.py ADDED
@@ -0,0 +1,23 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
import vllm_npu
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
__version__ = "0.4.2"
__all__ = [
    "LLM",
    "ModelRegistry",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "LLMEngine",
    "EngineArgs",
    "AsyncLLMEngine",
    "AsyncEngineArgs",
    "initialize_ray_cluster",
]
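
Since this module only re-exports the public entry points, a minimal usage sketch of that surface looks roughly as follows (the model name and sampling values below are illustrative placeholders, not taken from this package):

import vllm  # noqa: F401  (also pulls in the vllm_npu plugin via __init__)
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model; builds an LLMEngine internally
sampling = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], sampling)  # returns a list of RequestOutput
for request_output in outputs:
    # Each RequestOutput carries one or more CompletionOutput candidates.
    print(request_output.outputs[0].text)
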
vllm/_custom_ops.py ADDED
@@ -0,0 +1,251 @@
from typing import Dict, Optional, Tuple

import torch

try:
    from vllm._C import cache_ops as vllm_cache_ops
    from vllm._C import ops as vllm_ops
except ImportError:
    pass


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_new(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
                                num_kv_heads, scale, block_tables, seq_lens,
                                block_size, max_seq_len, alibi_slopes,
                                kv_cache_dtype, kv_scale)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
                                key_cache, value_cache, num_kv_heads, scale,
                                block_tables, seq_lens, block_size,
                                max_seq_len, alibi_slopes, kv_cache_dtype,
                                kv_scale)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                              is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
                                      cos_sin_cache, is_neox, rot_dim,
                                      cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    vllm_ops.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
                                   thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                              b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    vllm_ops.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                size_n, size_k)


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
                              codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                       num_bits)


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, g_idx: torch.Tensor,
                     perm: torch.Tensor, workspace: torch.Tensor,
                     num_bits: int, size_m: int, size_n: int, size_k: int,
                     is_k_full: bool) -> torch.Tensor:
    return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
                                     workspace, num_bits, size_m, size_n,
                                     size_k, is_k_full)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        vllm_ops.static_scaled_fp8_quant(output, input, scale)
    return output, scale


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
                                  sorted_token_ids, experts_ids,
                                  num_tokens_post_pad)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                     slot_mapping, kv_cache_dtype, kv_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                           slot_mapping, kv_cache_dtype)


def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: Dict[int, int]) -> None:
    vllm_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
    vllm_cache_ops.convert_fp8(output, input)


# TODO: cuda_utils, custom_ar
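
The wrappers above simply forward to the compiled vllm._C kernels, so their intended semantics are easiest to read against an eager PyTorch reference. The sketch below spells out the conventional meaning of silu_and_mul, assuming the usual [gate, up] split on the last dimension; it is a reading aid, not code shipped in this wheel:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Assumes the last dimension holds the concatenated [gate, up] halves.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

# The fused op instead fills a preallocated half-width output in place:
#   out = torch.empty(*x.shape[:-1], x.shape[-1] // 2, dtype=x.dtype, device=x.device)
#   silu_and_mul(out, x)
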
vllm/attention/__init__.py ADDED
@@ -0,0 +1,13 @@
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "AttentionBackend",
    "AttentionMetadata",
    "Attention",
    "get_attn_backend",
    "AttentionMetadataPerStage",
]
vllm/attention/backends/__init__.py
File without changes
vllm/attention/backends/abstract.py ADDED
@@ -0,0 +1,127 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
                    TypeVar)

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        raise NotImplementedError


@dataclass
class AttentionMetadataPerStage:
    """Attention metadata for a specific stage. I.e., prefill or decode."""

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


T = TypeVar("T", bound=AttentionMetadataPerStage)


@dataclass
class AttentionMetadata(Generic[T]):
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # The attention metadata for prefill requests in a batch.
    # None if there's no prefill requests in a batch.
    prefill_metadata: Optional[T]
    # The attention metadata for decode requests in a batch.
    # None if there's no decode requests in a batch.
    decode_metadata: Optional[T]
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor
    # The kv cache's data type.
    kv_cache_dtype: str

    def __post_init__(self):
        if self.num_prefill_tokens > 0:
            assert self.num_prefills > 0
            assert self.prefill_metadata is not None
        if self.num_decode_tokens > 0:
            assert self.decode_metadata is not None


class AttentionImpl(ABC):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_scale: float,
    ) -> torch.Tensor:
        raise NotImplementedError
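
A small illustration of asdict_zerocopy may help when reading the concrete backends that subclass these dataclasses: tensor fields are handed through by reference rather than deep-copied, which is the point of the helper. ToyMetadata below is a made-up subclass used only for this demonstration:

from dataclasses import dataclass
from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionMetadataPerStage

@dataclass
class ToyMetadata(AttentionMetadataPerStage):
    seq_lens_tensor: Optional[torch.Tensor]
    is_prompt: bool

meta = ToyMetadata(seq_lens_tensor=torch.tensor([4, 6]), is_prompt=True)
d = meta.asdict_zerocopy(skip_fields={"is_prompt"})
assert d["seq_lens_tensor"] is meta.seq_lens_tensor  # same tensor object, no copy
assert "is_prompt" not in d
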
vllm/attention/backends/flash_attn.py ADDED
@@ -0,0 +1,271 @@
"""Attention layer with Flash and PagedAttention.

NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch
from flash_attn import flash_attn_varlen_func

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)


class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
                             PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch.
    max_query_len: Optional[int]
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_seq_len,
                    max_seqlen_k=prefill_meta.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_seq_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
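
The seq_start_loc and subquery_start_loc tensors documented in FlashAttentionMetadata are plain exclusive prefix sums over the per-sequence lengths. A standalone sketch of how such a tensor is typically built (the actual construction happens in the model runner, which is not shown here):

import torch

seq_lens = [4, 6]  # example values from the comments above
seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0,
             out=seq_start_loc[1:])
# seq_start_loc is now [0, 4, 10]; flash_attn_varlen_func consumes it as
# cu_seqlens_q / cu_seqlens_k alongside the packed token tensors.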