vllm-npu 0.4.2__py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/worker/model_runner.py
@@ -0,0 +1,1168 @@
+import contextlib
+import time
+from enum import IntEnum
+from typing import Dict, List, NamedTuple, Optional, Set, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
+                            get_attn_backend)
+from vllm.attention.backends.flashinfer import FlashInferBackend
+from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
+from vllm.distributed.device_communicators import (custom_all_reduce,
+                                                    pynccl_utils)
+from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.model_loader import get_model
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
+                           SequenceGroupMetadata)
+from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
+                        is_pin_memory_available, make_tensor_with_pad)
+
+logger = init_logger(__name__)
+
+_PAD_SLOT_ID = -1
+LORA_WARMUP_RANK = 8
+_BATCH_SIZE_ALIGNMENT = 8
+# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
+# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
+]
+
+
+class PreparePromptMetadata(NamedTuple):
+    input_tokens: List[int]
+    input_positions: List[int]
+    attn_metadata: Optional[AttentionMetadataPerStage]
+    seq_lens: List[int]
+    query_lens: List[int]
+    lora_index_mapping: List[int]
+    lora_prompt_mapping: List[int]
+    lora_requests: Set[LoRARequest]
+    multi_modal_input: Optional[torch.Tensor]
+    slot_mapping: List[int]
+
+    @classmethod
+    def empty(cls):
+        return PreparePromptMetadata(
+            input_tokens=[],
+            input_positions=[],
+            attn_metadata=None,
+            seq_lens=[],
+            query_lens=[],
+            lora_index_mapping=[],
+            lora_prompt_mapping=[],
+            lora_requests=set(),
+            multi_modal_input=None,
+            slot_mapping=[],
+        )
+
+
+class PrepareDecodeMetadata(NamedTuple):
+    input_tokens: List[int]
+    input_positions: List[int]
+    attn_metadata: Optional[AttentionMetadata]
+    lora_index_mapping: List[int]
+    lora_prompt_mapping: List[int]
+    lora_requests: Set[LoRARequest]
+    slot_mapping: List[int]
+
+    @classmethod
+    def empty(cls):
+        return PrepareDecodeMetadata(
+            input_tokens=[],
+            input_positions=[],
+            attn_metadata=None,
+            lora_index_mapping=[],
+            lora_prompt_mapping=[],
+            lora_requests=set(),
+            slot_mapping=[],
+        )
+
+
+# How batches are constructed.
+class BatchType(IntEnum):
+    # Every batch is prefill.
+    PREFILL = 0
+    # Every batch is decode.
+    DECODE = 1
+    # Batch is a mixture of prefill and decode.
+    MIXED = 2
+
+
+class ModelRunner:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
+        is_driver_worker: bool = False,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
+    ):
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.lora_config = lora_config
+        self.load_config = load_config
+        self.is_driver_worker = is_driver_worker
+
+        # model_config can be None in tests/samplers/test_sampler.py.
+        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
+        self.sliding_window = (model_config.get_sliding_window()
+                               if model_config is not None else None)
+        self.device_config = (device_config
+                              if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
+
+        # Set after load_model.
+        self.lora_manager: LRUCacheWorkerLoRAManager = None
+
+        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+        self.graph_memory_pool: Optional[Tuple[
+            int, int]] = None  # Set during graph capture.
+
+        self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
+                                       if self.model_config is not None else 0)
+
+        self.pin_memory = is_pin_memory_available()
+        self.kv_cache_dtype = kv_cache_dtype
+        self.vision_language_config = vision_language_config
+
+        self.attn_backend = get_attn_backend(
+            self.model_config.dtype if model_config is not None else None)
+
+        # Lazy initialization
+        self.model: torch.nn.Module  # Set after load_model
+        self.block_size: int  # Set after initial profiling.
+        # When using CUDA graph, the input block tables must be padded to
+        # max_seq_len_to_capture. However, creating the block table in
+        # Python can be expensive. To optimize this, we cache the block table
+        # in numpy and only copy the actual input content at every iteration.
+        # The shape of the cached block table will be
+        # (max batch size to capture, max context len to capture / block size).
+        self.graph_block_tables: torch.Tensor  # Set after initial profiling.
+
+        # Set if the backend is flashinfer.
+        self.flashinfer_workspace_buffer: torch.Tensor
+
+    def load_model(self) -> None:
+        with CudaMemoryProfiler() as m:
+            self.model = get_model(
+                model_config=self.model_config,
+                device_config=self.device_config,
+                load_config=self.load_config,
+                lora_config=self.lora_config,
+                vision_language_config=self.vision_language_config,
+                parallel_config=self.parallel_config,
+                scheduler_config=self.scheduler_config,
+            )
+
+        self.model_memory_usage = m.consumed_memory
+        logger.info("Loading model weights took %.4f GB",
+                    self.model_memory_usage / float(2**30))
+
+        if self.lora_config:
+            assert hasattr(self.model, "supported_lora_modules"
+                           ) and self.model.supported_lora_modules, (
+                               "Model does not support LoRA")
+            assert hasattr(
+                self.model,
+                "embedding_modules"), "Model does not have embedding_modules"
+            assert hasattr(self.model, "embedding_padding_modules"
+                           ), "Model does not have embedding_padding_modules"
+            self.lora_manager = LRUCacheWorkerLoRAManager(
+                self.scheduler_config.max_num_seqs,
+                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
+                self.lora_config, self.device, self.model.embedding_modules,
+                self.model.embedding_padding_modules)
+            self.model = self.lora_manager.create_lora_manager(self.model)
+
+        if self.kv_cache_dtype == "fp8" and is_hip():
+            # Currently scaled KV cache is only enabled on ROCm
+            if self.model_config.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.model_config.quantization_param_path)
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__)
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!")
+        elif self.model_config.quantization_param_path is not None:
+            logger.warning("KV cache scaling factors provided, "
+                           "but the KV cache data type is not FP8. "
+                           "KV cache scaling factors will not be used.")
+
+    def set_block_size(self, block_size: int) -> None:
+        self.block_size = block_size
+
+        self.graph_block_tables = np.zeros(
+            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
+            dtype=np.int32)
+
+    def get_max_block_per_batch(self) -> int:
+        block_size = self.block_size
+        return (self.max_seq_len_to_capture + block_size - 1) // block_size
+
+    def _prepare_prompt(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> PreparePromptMetadata:
+        input_tokens: List[int] = []
+        input_positions: List[int] = []
+        slot_mapping: List[int] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()
+
+        seq_lens: List[int] = []
+        context_lens: List[int] = []
+        query_lens: List[int] = []
+        prefix_block_tables: List[List[int]] = []
+        multi_modal_input_list: List[torch.Tensor] = []
+
+        if len(seq_group_metadata_list) == 0:
+            return PreparePromptMetadata.empty()
+
+        for seq_group_metadata in seq_group_metadata_list:
+            assert seq_group_metadata.is_prompt
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            assert len(seq_ids) == 1
+            seq_id = seq_ids[0]
+
+            computed_block_nums = seq_group_metadata.computed_block_nums
+            if (self.scheduler_config is not None
+                    and self.scheduler_config.chunked_prefill_enabled
+                    and not (computed_block_nums is None
+                             or computed_block_nums == [])):
+                raise RuntimeError(
+                    "chunked prefill cannot be used with prefix caching "
+                    "now.")
+
+            token_chunk_size = seq_group_metadata.token_chunk_size
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            context_len = seq_data.get_num_computed_tokens()
+            # We should use get_len here because in case of preemption
+            # it contains output tokens.
+            seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
+            prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
+            seq_lens.append(seq_len)
+
+            # NOTE: This only works for oooooooxxx style attention.
+            if computed_block_nums is not None and len(
+                    computed_block_nums) > 0 and self.sliding_window is None:
+                # Prefix is not supported with sliding_window
+                context_len = len(computed_block_nums) * self.block_size
+                prompt_tokens = prompt_tokens[context_len:]
+                prefix_block_tables.append(computed_block_nums)
+            elif self.scheduler_config.chunked_prefill_enabled:
+                if seq_group_metadata.block_tables is not None:
+                    # Prefill has chunked before.
+                    block_table = seq_group_metadata.block_tables[seq_id]
+                    prefix_block_tables.append(block_table)
+                else:
+                    # The first prefill.
+                    prefix_block_tables.append([])
+            else:
+                prefix_block_tables.append([])
+                # Right now, prefill start is always 0. However, this
+                # assumption can be changed once chunked prefill is introduced.
+                assert context_len == 0
+
+            # actual prompt lens
+            context_lens.append(context_len)
+            query_lens.append(seq_len - context_len)
+
+            input_tokens.extend(prompt_tokens)
+            # NOTE(woosuk): Here we assume that the first token in the prompt
+            # is always the first token in the sequence.
+            input_positions.extend(list(range(context_len, seq_len)))
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)
+
+            lora_index_mapping += [lora_id] * (seq_len - context_len)
+            lora_prompt_mapping.extend(
+                [lora_id] *
+                (seq_len - context_len
+                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))
+
+            if seq_group_metadata.multi_modal_data:
+                multi_modal_input_list.append(
+                    seq_group_metadata.multi_modal_data.data)
+
+            if seq_group_metadata.block_tables is None:
+                # During memory profiling, the block tables are not initialized
+                # yet. In this case, we just use a dummy slot mapping.
+                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
+                continue
+
+            # Compute the slot mapping.
+            block_table = seq_group_metadata.block_tables[seq_id]
+
+            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
+            # where start_idx is max(0, seq_len - sliding_window).
+            # For example, if the prompt len is 10, sliding window is 8, and
+            # block size is 4, the first two tokens are masked and the slot
+            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
+            start_idx = 0
+            if self.sliding_window is not None:
+                assert context_len == 0, (
+                    "Prefix caching is currently not supported with "
+                    "sliding window attention")
+                start_idx = max(0, seq_len - self.sliding_window)
+
+            for i in range(context_len, seq_len):
+                if i < start_idx:
+                    slot_mapping.append(_PAD_SLOT_ID)
+                    continue
+
+                block_number = block_table[i // self.block_size]
+                block_offset = i % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)
+
+        max_query_len = max(query_lens)
+        max_seq_len = max(seq_lens)
+        assert max_query_len > 0
+
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device=self.device)
+
+        if multi_modal_input_list:
+            assert self.vision_language_config, (
+                "Multi-modal inputs are only supported by "
+                "vision language models.")
+            multi_modal_input = torch.cat(multi_modal_input_list,
+                                          dim=0).to(self.device)
+        else:
+            multi_modal_input = None
+
+        # Prepare prefix block tables
+        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
+        block_tables = make_tensor_with_pad(
+            prefix_block_tables,
+            max_len=max_prompt_block_table_len,
+            pad=0,
+            dtype=torch.int,
+            device=self.device,
+        )
+
+        # Query length can be shorter than key (i.e., prompt) when prefill
+        # is chunked or prefix cached.
+        query_lens_tensor = torch.tensor(query_lens,
+                                         dtype=torch.long,
+                                         device=self.device)
+        subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                         dtype=torch.int32,
+                                         device=self.device)
+
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=self.device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=self.device)
+
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=subquery_start_loc.dtype,
+                     out=subquery_start_loc[1:])
+
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+
+        if self.attn_backend is FlashInferBackend:
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=True,
+                use_cuda_graph=False,
+                seq_start_loc=seq_start_loc,
+                max_seq_len=max_seq_len,
+                block_tables=block_tables)
+        else:
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=True,
+                seq_lens=seq_lens,
+                seq_lens_tensor=seq_lens_tensor,
+                max_query_len=max_query_len,
+                max_seq_len=max_seq_len,
+                subquery_start_loc=subquery_start_loc,
+                seq_start_loc=seq_start_loc,
+                context_lens_tensor=context_lens_tensor,
+                block_tables=block_tables,
+                use_cuda_graph=False,
+            )
+
+        return PreparePromptMetadata(
+            input_tokens=input_tokens,
+            input_positions=input_positions,
+            attn_metadata=attn_metadata,
+            seq_lens=seq_lens,
+            query_lens=query_lens,
+            lora_index_mapping=lora_index_mapping,
+            lora_prompt_mapping=lora_prompt_mapping,
+            lora_requests=lora_requests,
+            multi_modal_input=multi_modal_input,
+            slot_mapping=slot_mapping,
+        )
+
+    def _prepare_decode(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> PrepareDecodeMetadata:
+        input_tokens: List[int] = []
+        input_positions: List[int] = []
+        slot_mapping: List[int] = []
+        seq_lens: List[int] = []
+        block_tables: List[List[int]] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()
+
+        # The following fields are only for flashinfer
+        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
+        # for the precise definition of the following fields.
+        # An example:
+        # request 1, page indices [0, 5, 8]
+        # request 2, page indices [1, 6, 7]
+        # request 3, page indices [3, 4]
+        # paged_kv_indices is a concatenation of page indices of all requests:
+        # [0, 5, 8, 1, 6, 7, 3, 4]
+        # paged_kv_indptr is used to index into paged_kv_indices:
+        # [0, 3, 6, 8]
+        paged_kv_indices: List[int] = []
+        # 0 at the beginning of paged_kv_indptr indicates the start of the
+        # first request's page indices in the paged_kv_indices list.
+        paged_kv_indptr: List[int] = [0]
+        # paged_kv_last_page_len is the length of the last page of each request
+        paged_kv_last_page_len: List[int] = []
+
+        if len(seq_group_metadata_list) == 0:
+            return PrepareDecodeMetadata.empty()
+
+        for seq_group_metadata in seq_group_metadata_list:
+            assert not seq_group_metadata.is_prompt
+            assert seq_group_metadata.token_chunk_size == 1
+
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)
+
+            for seq_id in seq_ids:
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                generation_token = seq_data.get_last_token_id()
+                input_tokens.append(generation_token)
+
+                seq_len = seq_data.get_len()
+                position = seq_len - 1
+                input_positions.append(position)
+
+                seq_len = seq_len if self.sliding_window is None else min(
+                    seq_len, self.sliding_window)
+                seq_lens.append(seq_len)
+
+                block_table = seq_group_metadata.block_tables[seq_id]
+                block_number = block_table[position // self.block_size]
+                block_offset = position % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)
+                lora_index_mapping.append(lora_id)
+                lora_prompt_mapping.append(lora_id)
+
+                if self.sliding_window is not None:
+                    sliding_window_blocks = (self.sliding_window //
+                                             self.block_size)
+                    block_table = block_table[-sliding_window_blocks:]
+                block_tables.append(block_table)
+
+                paged_kv_indices.extend(block_table)
+                paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
+                last_page_len = seq_data.get_len() % self.block_size
+                if last_page_len == 0:
+                    last_page_len = self.block_size
+                paged_kv_last_page_len.append(last_page_len)
+
+        # vLLM uses cuda graph only for decoding requests.
+        # See `capture_model` API for more details.
+        # For decoding requests, batch_size == input_tokens.
+        batch_size = len(input_tokens)
+        max_seq_len = max(seq_lens)
+        use_captured_graph = (not self.model_config.enforce_eager
+                              and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                              and max_seq_len <= self.max_seq_len_to_capture)
+        if use_captured_graph:
+            graph_batch_size = _get_graph_batch_size(batch_size)
+            assert graph_batch_size >= batch_size
+            for _ in range(graph_batch_size - batch_size):
+                input_tokens.append(0)
+                input_positions.append(0)
+                slot_mapping.append(_PAD_SLOT_ID)
+                seq_lens.append(1)
+                block_tables.append([])
+                lora_index_mapping.append(0)
+            batch_size = graph_batch_size
+
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=self.device)
+
+        if use_captured_graph:
+            # When using cuda-graph all these tensors should be
+            # padded.
+            assert seq_lens_tensor.shape[0] == len(input_tokens)
+            assert seq_lens_tensor.shape[0] == len(input_positions)
+            assert seq_lens_tensor.shape[0] == len(slot_mapping)
+
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = self.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.tensor(input_block_tables, device=self.device)
+        else:
+            max_block_table_len = max(
+                len(block_table) for block_table in block_tables)
+            block_tables = make_tensor_with_pad(
+                block_tables,
+                max_len=max_block_table_len,
+                pad=0,
+                dtype=torch.int,
+                device=self.device,
+            )
+
+        if self.attn_backend is FlashInferBackend:
+            if not hasattr(self, "flashinfer_workspace_buffer"):
+                # Allocate 16MB workspace buffer
+                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
+                self.flashinfer_workspace_buffer = torch.empty(
+                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
+            paged_kv_indptr = torch.tensor(paged_kv_indptr,
+                                           dtype=torch.int,
+                                           device=self.device)
+            paged_kv_indices = torch.tensor(paged_kv_indices,
+                                            dtype=torch.int,
+                                            device=self.device)
+            paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
+                                                  dtype=torch.int,
+                                                  device=self.device)
+            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
+                                                      self.model_config.dtype)
+
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=False,
+                use_cuda_graph=False,
+                workspace_buffer=self.flashinfer_workspace_buffer,
+                paged_kv_indptr=paged_kv_indptr,
+                paged_kv_indices=paged_kv_indices,
+                paged_kv_last_page_len=paged_kv_last_page_len,
+                num_qo_heads=self.model_config.get_num_attention_heads(
+                    self.parallel_config),
+                num_kv_heads=self.model_config.get_num_kv_heads(
+                    self.parallel_config),
+                head_dim=self.model_config.get_head_size(),
+                page_size=self.block_size,
+                data_type=kv_cache_dtype)
+        else:
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=False,
+                seq_lens=None,
+                seq_lens_tensor=seq_lens_tensor,
+                max_query_len=None,
+                max_seq_len=max_seq_len,
+                subquery_start_loc=None,
+                seq_start_loc=None,
+                context_lens_tensor=None,
+                block_tables=block_tables,
+                use_cuda_graph=use_captured_graph,
+            )
+        return PrepareDecodeMetadata(
+            input_tokens=input_tokens,
+            input_positions=input_positions,
+            attn_metadata=attn_metadata,
+            lora_index_mapping=lora_index_mapping,
+            lora_prompt_mapping=lora_prompt_mapping,
+            lora_requests=lora_requests,
+            slot_mapping=slot_mapping,
+        )
+
+    def prepare_input_tensors(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
+               Set[LoRARequest], LoRAMapping, torch.Tensor]:
+        if self.is_driver_worker:
+            prefill_reqs = []
+            decode_reqs = []
+            for seq_group_meta in seq_group_metadata_list:
+                if seq_group_meta.is_prompt:
+                    prefill_reqs.append(seq_group_meta)
+                else:
+                    decode_reqs.append(seq_group_meta)
+
+            # Prepare input tensors.
+            (
+                input_tokens,
+                input_positions,
+                prefill_attn_metadata,
+                seq_lens,
+                query_lens,
+                lora_index_mapping,
+                lora_prompt_mapping,
+                lora_requests,
+                multi_modal_input,
+                slot_mapping,
+            ) = self._prepare_prompt(prefill_reqs)
+            (
+                decode_input_tokens,
+                decode_input_positions,
+                decode_attn_metadata,
+                decode_lora_index_mapping,
+                decode_lora_prompt_mapping,
+                decode_lora_requests,
+                decode_slot_mapping,
+            ) = self._prepare_decode(decode_reqs)
+            sampling_metadata = SamplingMetadata.prepare(
+                seq_group_metadata_list, seq_lens, query_lens, self.device,
+                self.pin_memory)
+
+            if not self.scheduler_config.chunked_prefill_enabled:
+                assert (len(prefill_reqs) and len(decode_reqs)) == 0
+
+            num_prefills = len(seq_lens)
+            num_prefill_tokens = len(input_tokens)
+            num_decode_tokens = len(decode_input_tokens)
+
+            # Coalesce tensors. Note that attn_metadata is currently not
+            # coalesced for simplicity.
+            input_tokens.extend(decode_input_tokens)
+            input_positions.extend(decode_input_positions)
+            slot_mapping.extend(decode_slot_mapping)
+            lora_index_mapping.extend(decode_lora_index_mapping)
+            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
+            lora_requests.update(decode_lora_requests)
+
+            input_tokens = torch.tensor(input_tokens,
+                                        dtype=torch.long,
+                                        device=self.device)
+            input_positions = torch.tensor(input_positions,
+                                           dtype=torch.long,
+                                           device=self.device)
+            slot_mapping = torch.tensor(slot_mapping,
+                                        dtype=torch.long,
+                                        device=self.device)
+
+            if self.lora_config:
+                lora_mapping = LoRAMapping(
+                    lora_index_mapping,
+                    lora_prompt_mapping,
+                )
+            else:
+                lora_mapping = None
+
+            # Broadcast the metadata.
+            # If batch contains both prefill and decode, it sends 2 broadcasts.
+            # If it only contains 1 type, it triggers a single broadcast.
+            if (prefill_attn_metadata is not None
+                    and decode_attn_metadata is not None):
+                batch_type = BatchType.MIXED
+            elif prefill_attn_metadata is not None:
+                batch_type = BatchType.PREFILL
+            else:
+                batch_type = BatchType.DECODE
+
+            metadata_dict = {
+                "input_tokens": input_tokens,
+                "input_positions": input_positions,
+                "selected_token_indices":
+                sampling_metadata.selected_token_indices,
+                "lora_requests": lora_requests,
+                "lora_mapping": lora_mapping,
+                "multi_modal_input": multi_modal_input,
+                "num_prefill_tokens": num_prefill_tokens,
+                "num_decode_tokens": num_decode_tokens,
+                "slot_mapping": slot_mapping,
+                "num_prefills": num_prefills,
+                "batch_type": batch_type,
+            }
+            if prefill_attn_metadata is not None:
+                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
+            else:
+                assert decode_attn_metadata is not None
+                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
+            broadcast_tensor_dict(metadata_dict, src=0)
+
+            # Broadcast decode attn metadata for mixed batch type.
+            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
+            # We can potentially reduce the overhead by coalescing tensors.
+            if batch_type == BatchType.MIXED:
+                assert decode_attn_metadata is not None
+                metadata_dict = decode_attn_metadata.asdict_zerocopy()
+                broadcast_tensor_dict(metadata_dict, src=0)
+        else:
+            metadata_dict = broadcast_tensor_dict(src=0)
+            input_tokens = metadata_dict.pop("input_tokens")
+            input_positions = metadata_dict.pop("input_positions")
+            slot_mapping = metadata_dict.pop("slot_mapping")
+            num_prefills = metadata_dict.pop("num_prefills")
+            selected_token_indices = metadata_dict.pop(
+                "selected_token_indices")
+            lora_mapping = metadata_dict.pop("lora_mapping")
+            lora_requests = metadata_dict.pop("lora_requests")
+            multi_modal_input = metadata_dict.pop("multi_modal_input")
+            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
+            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
+            batch_type = metadata_dict.pop("batch_type")
+
+            # Create an attention metadata.
+            prefill_attn_metadata = None
+            decode_attn_metadata = None
+            if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
+                prefill_attn_metadata = self.attn_backend.make_metadata(
+                    **metadata_dict)
+            else:
+                decode_attn_metadata = self.attn_backend.make_metadata(
+                    **metadata_dict)
+            sampling_metadata = SamplingMetadata(
+                seq_groups=None,
+                selected_token_indices=selected_token_indices,
+                categorized_sample_indices=None,
+                num_prompts=0,
+            )
+
+            # if it is a mixed batch, decode attn_metadata is broadcasted
+            # separately.
+            if batch_type == BatchType.MIXED:
+                metadata_dict = broadcast_tensor_dict(src=0)
+                decode_attn_metadata = self.attn_backend.make_metadata(
+                    **metadata_dict)
+
+        attn_metadata = AttentionMetadata(
+            num_prefills=num_prefills,
+            slot_mapping=slot_mapping,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            prefill_metadata=prefill_attn_metadata,
+            decode_metadata=decode_attn_metadata,
+            kv_cache_dtype=self.kv_cache_dtype,
+        )
+
+        return (input_tokens, input_positions, attn_metadata,
+                sampling_metadata, lora_requests, lora_mapping,
+                multi_modal_input)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        kv_caches: List[torch.Tensor],
+    ) -> Optional[SamplerOutput]:
+        (input_tokens, input_positions, attn_metadata, sampling_metadata,
+         lora_requests, lora_mapping, multi_modal_input
+         ) = self.prepare_input_tensors(seq_group_metadata_list)
+
+        if self.lora_config:
+            self.set_active_loras(lora_requests, lora_mapping)
+
+        # Currently cuda graph is only supported by the decode phase.
+        prefill_meta = attn_metadata.prefill_metadata
+        decode_meta = attn_metadata.decode_metadata
+        if prefill_meta is None and decode_meta.use_cuda_graph:
+            graph_batch_size = input_tokens.shape[0]
+            model_executable = self.graph_runners[graph_batch_size]
+        else:
+            model_executable = self.model
+        execute_model_kwargs = {
+            "input_ids": input_tokens,
+            "positions": input_positions,
+            "kv_caches": kv_caches,
+            "attn_metadata": attn_metadata,
+        }
+        if self.vision_language_config:
+            execute_model_kwargs.update({"image_input": multi_modal_input})
+        hidden_states = model_executable(**execute_model_kwargs)
+
+        # Compute the logits.
+        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+
+        # Only perform sampling in the driver worker.
+        if not self.is_driver_worker:
+            return None
+
+        # Sample the next token.
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )
+
+        return output
+
+    @torch.inference_mode()
+    def profile_run(self) -> None:
+        # Enable top-k sampling to reflect the accurate memory usage.
+        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
+        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+
+        # This represents the maximum number of different requests
+        # that will have unique loras, and therefore the max amount of memory
+        # consumption. Create dummy lora request copies from the lora request
+        # passed in, which contains a lora from the lora warmup path.
+        dummy_lora_requests = []
+        dummy_lora_requests_per_seq = []
+        if self.lora_config:
+            for idx in range(self.lora_config.max_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_local_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
+
+        # Profile memory usage with max_num_sequences sequences and the total
+        # number of tokens equal to max_num_batched_tokens.
+        seqs: List[SequenceGroupMetadata] = []
+        # Additional GPU memory may be needed for vision encoding, which needs
+        # to be accounted for when calculating the GPU blocks for
+        # vLLM block manager.
+        # To exercise the worst scenario for GPU memory consumption,
+        # the number of seqs (batch_size) is chosen to maximize the number
+        # of images processed.
+        if self.vision_language_config:
+            max_num_seqs = min(
+                max_num_seqs,
+                int(max_num_batched_tokens /
+                    self.vision_language_config.image_feature_size))
+        for group_id in range(max_num_seqs):
+            seq_len = (max_num_batched_tokens // max_num_seqs +
+                       (group_id < max_num_batched_tokens % max_num_seqs))
+            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
+                seq_len, self.vision_language_config)
+            seq = SequenceGroupMetadata(
+                request_id=str(group_id),
+                is_prompt=True,
+                seq_data={group_id: seq_data},
+                sampling_params=sampling_params,
+                block_tables=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
+                multi_modal_data=fake_multi_modal_input,
+            )
+            seqs.append(seq)
+
+        # Run the model with the dummy inputs.
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        kv_caches = [None] * num_layers
+        self.execute_model(seqs, kv_caches)
+        torch.cuda.synchronize()
+        return
+
+    def remove_all_loras(self):
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.remove_all_loras()
+
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.set_active_loras(lora_requests, lora_mapping)
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_loras()
+
+    @torch.inference_mode()
+    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
+        """Cuda graph capture a model.
+
+        Note that CUDA graph's performance gain is negligible if number
+        of batched tokens are larger than 200. And since CUDA graph
+        requires fixed sized tensors, supporting large/variable batch
+        size requires high GPU memory overhead. Thus, vLLM only captures
+        decoding requests. Mixed batch (chunked prefill + decoding) or
+        prefill requests are not captured.
+
+        Since it is used for decoding-only, it assumes there's only 1 token
+        per sequence in the batch.
+        """
+        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
+        # deleted before the CUDA graphs.
+        self.pynccl_backend = pynccl_utils.get_nccl_backend()
+
+        assert not self.model_config.enforce_eager
+        logger.info("Capturing the model for CUDA graphs. This may lead to "
+                    "unexpected consequences if the model is not static. To "
+                    "run the model in eager mode, set 'enforce_eager=True' or "
+                    "use '--enforce-eager' in the CLI.")
+        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
+                    "If you are running out of memory, consider decreasing "
+                    "`gpu_memory_utilization` or enforcing eager mode. "
+                    "You can also reduce the `max_num_seqs` as needed "
+                    "to decrease memory usage.")
+        start_time = time.perf_counter()
+
+        # Prepare dummy inputs. These will be reused for all batch sizes.
+        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
+        slot_mapping.fill_(_PAD_SLOT_ID)
+        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
+        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+
+        graph_batch_size = _get_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
+        batch_size_capture_list = [
+            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
+        ]
+
+        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
+        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
+        # either custom all-reduce kernel or pynccl. When not using CUDA
+        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        # We always prioritize using custom all-reduce kernel but fall back
+        # to PyTorch or pynccl if it is disabled or not supported.
+        with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
+            for batch_size in reversed(batch_size_capture_list):
+                # Create dummy attn_metadata.
+                decode_metadata = self.attn_backend.make_metadata(
+                    is_prompt=False,
+                    seq_lens=None,
+                    seq_lens_tensor=seq_lens[:batch_size],
+                    max_query_len=None,
+                    max_seq_len=self.max_seq_len_to_capture,
+                    subquery_start_loc=None,
+                    seq_start_loc=None,
+                    context_lens_tensor=None,
+                    block_tables=block_tables[:batch_size],
+                    use_cuda_graph=True,
+                )
+                attn_metadata = AttentionMetadata(
+                    num_prefills=0,
+                    num_prefill_tokens=0,
+                    num_decode_tokens=batch_size,
+                    slot_mapping=slot_mapping[:batch_size],
+                    prefill_metadata=None,
+                    decode_metadata=decode_metadata,
+                    kv_cache_dtype=self.kv_cache_dtype,
+                )
+
+                if self.lora_config:
+                    lora_mapping = LoRAMapping(
+                        [0] * batch_size,
+                        [0] * batch_size,
+                    )
+                    self.set_active_loras(set(), lora_mapping)
+
+                graph_runner = CUDAGraphRunner(self.model)
+                graph_runner.capture(
+                    input_tokens[:batch_size],
+                    input_positions[:batch_size],
+                    kv_caches,
+                    attn_metadata,
+                    memory_pool=self.graph_memory_pool,
+                )
+                self.graph_memory_pool = graph_runner.graph.pool()
+                self.graph_runners[batch_size] = graph_runner
+
+        end_time = time.perf_counter()
+        elapsed_time = end_time - start_time
+        # This usually takes < 10 seconds.
+        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
+
+    def __del__(self) -> None:
+        # Delete the CUDA graphs before deleting the pynccl communicator.
+        # NOTE(woosuk): This is necessary because otherwise deadlocks can
+        # happen.
+        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
+        # TODO(youkaichao): when we get enough user feedback that pynccl is
+        # more stable than cupy, we can remove this, e.g. in v0.4.1.
+        self.graph_runners.clear()
+        self.pynccl_backend = None
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+
+
+class CUDAGraphRunner:
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.input_buffers: Dict[str, torch.Tensor] = {}
+        self.output_buffers: Dict[str, torch.Tensor] = {}
+
+        self._graph: Optional[torch.cuda.CUDAGraph] = None
+
+    @property
+    def graph(self):
+        assert self._graph is not None
+        return self._graph
+
+    def capture(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        memory_pool,
+        **kwargs,
+    ) -> None:
+        assert self._graph is None
+        # Run the model once without capturing the graph.
+        # This is to make sure that the captured graph does not include the
+        # kernel launches for initial benchmarking (e.g., Triton autotune).
+        with _maybe_pynccl():
+            self.model(
+                input_ids,
+                positions,
+                kv_caches,
+                attn_metadata,
+                **kwargs,
+            )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
+        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
+            with _maybe_pynccl():
+                hidden_states = self.model(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    attn_metadata,
+                    **kwargs,
+                )
+        torch.cuda.synchronize()
+
+        # Save the input and output buffers.
+        self.input_buffers = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
+        self.output_buffers = {"hidden_states": hidden_states}
+        return
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        **kwargs,
+    ) -> torch.Tensor:
+        # KV caches are fixed tensors, so we don't need to copy them.
+        del kv_caches
+
+        # Copy the input tensors to the input buffers.
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        self.input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        # Run the graph.
+        self.graph.replay()
+
+        # Return the output tensor.
+        return self.output_buffers["hidden_states"]
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+
+@contextlib.contextmanager
+def _maybe_pynccl():
+    if pynccl_utils.is_initialized(
+    ) and not custom_all_reduce.is_initialized():
+        with with_pynccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
+def _get_graph_batch_size(batch_size: int) -> int:
+    """Returns the padded batch size given actual batch size.

+    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
+    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
+    """
+    if batch_size <= 2:
+        return batch_size
+    elif batch_size <= 4:
+        return 4
+    else:
+        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
+                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+
+def _prepare_fake_inputs(
+        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
+    """Prepare fake inputs for profile run."""
+    if vision_language_config:
+        prompt_tokens = [
+            vision_language_config.image_token_id
+        ] * vision_language_config.image_feature_size + [0] * (
+            seq_len - vision_language_config.image_feature_size)
+        fake_image_input = MultiModalData(
+            type=MultiModalData.Type.IMAGE,
+            data=torch.zeros(vision_language_config.image_input_shape,
+                             dtype=torch.float16))
+    else:
+        prompt_tokens = [0] * seq_len
+        fake_image_input = None
+    return SequenceData(prompt_tokens), fake_image_input
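
For readers following the flashinfer bookkeeping in `_prepare_decode` above, the following standalone sketch (illustrative only, not part of the wheel) mirrors that loop and reproduces the worked example from the in-code comments; the block size and sequence lengths below are made-up values.

    # Sketch of the paged-KV bookkeeping done per decode request above.
    # Made-up example: 3 requests with the block tables from the comment.
    block_size = 16
    block_tables = [[0, 5, 8], [1, 6, 7], [3, 4]]  # one block table per request
    seq_lens = [40, 35, 32]                        # tokens stored per request

    paged_kv_indices = []        # concatenation of all requests' page indices
    paged_kv_indptr = [0]        # prefix sums indexing into paged_kv_indices
    paged_kv_last_page_len = []  # valid tokens in each request's last page

    for block_table, seq_len in zip(block_tables, seq_lens):
        paged_kv_indices.extend(block_table)
        paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
        last_page_len = seq_len % block_size or block_size
        paged_kv_last_page_len.append(last_page_len)

    assert paged_kv_indices == [0, 5, 8, 1, 6, 7, 3, 4]
    assert paged_kv_indptr == [0, 3, 6, 8]
    assert paged_kv_last_page_len == [8, 3, 16]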
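
Similarly, a minimal sketch (again illustrative, not part of the wheel) of how decode batch sizes are padded to a captured CUDA graph size, based on `_get_graph_batch_size` and `_BATCH_SIZE_ALIGNMENT` as defined above.

    # Reproduces the padding rule used before looking up self.graph_runners.
    _BATCH_SIZE_ALIGNMENT = 8

    def _get_graph_batch_size(batch_size: int) -> int:
        if batch_size <= 2:
            return batch_size
        elif batch_size <= 4:
            return 4
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    for bs in (1, 3, 5, 17, 250):
        print(bs, "->", _get_graph_batch_size(bs))
    # 1 -> 1, 3 -> 4, 5 -> 8, 17 -> 24, 250 -> 256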