vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/sequence.py ADDED
@@ -0,0 +1,766 @@
"""Sequence and its related classes."""
import copy
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams

if TYPE_CHECKING:
    import torch

    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics


@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of chosen token
        rank: The vocab rank of chosen token (>=1)
        decoded_token: The decoded chosen token index
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]


class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason


class SequenceStage(enum.Enum):
    PREFILL = enum.auto()
    DECODE = enum.auto()


@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")


class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data: SequenceData = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int):
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(
            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def _append_logical_block(self) -> None:
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling
    generator: Optional = None  # type: ignore


class MultiModalData:
    """Multi modal request.

    Args:
        type: The data type.
        data: The actual data.
            The required shape and semantic meaning of it depends on the vision
            language config of the hosted model.
            See `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.type = type
        self.data = data


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data associated with the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Sets the last token time for Request level timings."""
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        # recomputed, the time between iterations is counted
        # in TPOT, rather than recalculating TTFT (since from the )
        # POV of the user, there is simply a long generation delay.
        if (self.metrics.first_token_time is None
                and self.get_seqs()[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if self.sampling_params.best_of > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        num_uncomputed_tokens = 0
        for seq in self.get_seqs():
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequences should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")


class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        if self._token_chunk_size is None:
            if is_prompt:
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def token_chunk_size(self) -> Optional[int]:
        """Return the number of tokens to be processed (chunk size)."""
        return self._token_chunk_size


class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal


class SequenceGroupOutput:
    """The model output associated with a sequence group."""

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceGroupOutput):
            raise NotImplementedError()
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

    This datastructure implements methods so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[SequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
                                  self.sampled_token_ids.shape)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")


@dataclass
class ExecuteModelRequest:
    """The model execution request."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. Dict of CPU -> GPU block number.
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    # Blocks to swap out. Dict of GPU -> CPU block number.
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    # Blocks to copy. Source to a list of dest blocks.
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list."""
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
        )
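
Usage note (not part of the wheel contents): a minimal sketch of how the classes above fit together, assuming the package and its dependencies are installed so that vllm.sequence and vllm.sampling_params import cleanly. The token ids, logprob value, and request id below are made up for illustration.

import time

from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, Sequence, SequenceGroup

# Wrap pre-tokenized prompt ids in a Sequence; block_size should match the
# block manager / cache engine configuration (illustrative values here).
seq = Sequence(seq_id=0,
               prompt="Hello",
               prompt_token_ids=[101, 102, 103],
               block_size=16,
               eos_token_id=2)
assert seq.is_prefill() and seq.get_prompt_len() == 3

# Append one generated token together with its logprob entry.
seq.append_token_id(token_id=7, logprobs={7: Logprob(logprob=-0.25, rank=1)})

# Group the sequence under a request so the scheduler can track it.
group = SequenceGroup(request_id="req-0",
                      seqs=[seq],
                      sampling_params=SamplingParams(),
                      arrival_time=time.time())
print(group.num_seqs(), seq.get_len())  # 1 4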