vllm-npu 0.4.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/sequence.py
ADDED
@@ -0,0 +1,766 @@
|
|
1
|
+
"""Sequence and its related classes."""
|
2
|
+
import copy
|
3
|
+
import enum
|
4
|
+
from dataclasses import dataclass, field
|
5
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
6
|
+
|
7
|
+
from vllm.block import LogicalTokenBlock
|
8
|
+
from vllm.lora.request import LoRARequest
|
9
|
+
from vllm.sampling_params import SamplingParams
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
import torch
|
13
|
+
|
14
|
+
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token, if it has been
            detokenized (None otherwise).
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# Prompt logprobs: a list of {token_id -> Logprob} mappings, where an entry may
# be None. NOTE(review): presumably one entry per prompt token position, with
# None where no logprobs were computed — confirm against the sampler.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs: one {token_id -> Logprob} mapping per generated token
# (Sequence.append_token_id appends one mapping per appended token).
SampleLogprobs = List[Dict[int, Logprob]]
|
36
|
+
|
37
|
+
|
38
|
+
class SequenceStatus(enum.Enum):
    """Status of a sequence.

    Newly created sequences start as WAITING. The FINISHED_* members are the
    terminal states recognized by ``is_finished``.
    """
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is one of the FINISHED_* states."""
        terminal_states = (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )
        return status in terminal_states

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns None for any status that is not a finished state.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        # Ignored sequences are those whose prompt lengths exceed the model's
        # length cap, so, as in the OpenAI API, they report "length" exactly
        # like sequences that hit the length cap during generation.
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None
|
73
|
+
|
74
|
+
|
75
|
+
class SequenceStage(enum.Enum):
    """Computation phase of a sequence.

    PREFILL: some of the sequence's tokens have not been computed yet.
    DECODE: every token so far has been computed, so only one new token is
        processed per step.
    """
    PREFILL = enum.auto()
    DECODE = enum.auto()
|
78
|
+
|
79
|
+
|
80
|
+
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The reference time of the most recent token-latency
            measurement; the next latency is measured relative to it.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
|
97
|
+
|
98
|
+
|
99
|
+
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.

    Note:
        Both token-ID lists are stored by reference (not copied), so the
        caller's lists and this object's lists are the same objects.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        # Create a fresh list per instance instead of sharing a default.
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens (prompt + output) that have already been run
        # against the model.
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append a generated token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list of the prompt tokens followed by the output
        tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) that are already
        computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Advance the computed-token count by ``num_new_computed_tokens``.

        Once every token (prompt + output) is computed, the stage switches to
        DECODE.

        Raises:
            AssertionError: (when assertions are enabled) if the updated count
                exceeds the total number of tokens. The count is updated
                before this check.
        """
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) not yet computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there is
        no output yet."""
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        """The current computation stage (PREFILL or DECODE)."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
|
192
|
+
|
193
|
+
|
194
|
+
class Sequence:
    """A single sequence: its token data, status, and logical token blocks.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: Optional end-of-sequence token ID stored on the
            sequence.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data: SequenceData = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        # Start with no logical blocks, then fill them with the prompt.
        self.logical_token_blocks: List[LogicalTokenBlock] = []
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Offsets used for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens as strings; None until assigned.
        self.tokens: Optional[List[str]] = None

    @property
    def lora_int_id(self) -> int:
        """Integer LoRA ID, or 0 when no LoRA request is attached."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text, holding back the trailing
        ``buffer_length`` characters while the sequence is unfinished.

        A finished sequence (or a falsy ``buffer_length``) yields the full
        text.
        """
        if not buffer_length or self.is_finished():
            return self.output_text
        return self.output_text[:-buffer_length]

    def hash_of_block(self, logical_idx: int) -> int:
        # TODO This can produce incorrect hash when block size > prompt size

        # The hash covers every token from the start of the sequence through
        # the end of block `logical_idx`, together with the LoRA ID.
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        prefix_len = self.num_hashed_tokens_of_block(logical_idx)
        prefix = tuple(self.data.get_token_ids()[0:prefix_len])
        return hash((prefix, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens from the sequence start through the end of block
        ``logical_idx``."""
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def _append_logical_block(self) -> None:
        """Append a new, empty logical block numbered after the existing
        ones."""
        new_block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(new_block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Write ``token_ids`` into the logical blocks, opening a new block
        whenever there is none or the last one is full."""
        offset = 0
        while offset < len(token_ids):
            blocks = self.logical_token_blocks
            if not blocks or blocks[-1].is_full():
                self._append_logical_block()
            tail = self.logical_token_blocks[-1]

            free_slots = tail.get_num_empty_slots()
            tail.append_tokens(token_ids[offset:offset + free_slots])
            offset += free_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a sampled token together with its logprobs mapping.

        ``logprobs`` must contain an entry for ``token_id``.
        """
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Compute the length-penalized beam search score.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF's implementation does not count the EOS token towards
            # the length; we mirror that here for testing.
            ends_with_eos = (eos_token_id is not None
                             and self.get_last_token_id() == eos_token_id)
            if ends_with_eos:
                seq_len -= 1
        score = self.get_cumulative_logprob()
        return score / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying ``new_seq_id``."""
        child = copy.deepcopy(self)
        child.seq_id = new_seq_id
        return child

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            1 in the decode stage; otherwise the number of tokens that still
            need to be computed.
        """
        if self.data.stage != SequenceStage.DECODE:
            return self.data.get_num_uncomputed_tokens()
        return 1

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        num_blocks = len(self.logical_token_blocks)
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={num_blocks})")
|
366
|
+
|
367
|
+
|
368
|
+
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""

    # torch.Generator used in seeded sampling. The annotation is a forward
    # reference because torch is only imported under TYPE_CHECKING.
    generator: Optional["torch.Generator"] = None
|
374
|
+
|
375
|
+
|
376
|
+
class MultiModalData:
    """Multi-modal payload attached to a request.

    Args:
        type: The data type (a ``MultiModalData.Type`` member).
        data: The actual data.
            The required shape and semantic meaning of it depends on the vision
            language config of the hosted model.
            See `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        """Supported kinds of multi-modal data."""
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.type, self.data = type, data
|
393
|
+
|
394
|
+
|
395
|
+
class SequenceGroup:
|
396
|
+
"""A group of sequences that are generated from the same prompt.
|
397
|
+
|
398
|
+
Args:
|
399
|
+
request_id: The ID of the request.
|
400
|
+
seqs: The list of sequences.
|
401
|
+
sampling_params: The sampling parameters used to generate the outputs.
|
402
|
+
arrival_time: The arrival time of the request.
|
403
|
+
lora_request: LoRA request.
|
404
|
+
multi_modal_data: Multi modal data associated with the request.
|
405
|
+
"""
|
406
|
+
|
407
|
+
def __init__(
|
408
|
+
self,
|
409
|
+
request_id: str,
|
410
|
+
seqs: List[Sequence],
|
411
|
+
sampling_params: SamplingParams,
|
412
|
+
arrival_time: float,
|
413
|
+
lora_request: Optional[LoRARequest] = None,
|
414
|
+
multi_modal_data: Optional[MultiModalData] = None,
|
415
|
+
) -> None:
|
416
|
+
self.request_id = request_id
|
417
|
+
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
418
|
+
self.sampling_params = sampling_params
|
419
|
+
self.metrics = RequestMetrics(arrival_time=arrival_time,
|
420
|
+
last_token_time=arrival_time,
|
421
|
+
first_scheduled_time=None,
|
422
|
+
first_token_time=None,
|
423
|
+
time_in_queue=None)
|
424
|
+
self.lora_request = lora_request
|
425
|
+
self.prompt_logprobs: Optional[PromptLogprobs] = None
|
426
|
+
self.state = SequenceGroupState()
|
427
|
+
self.multi_modal_data = multi_modal_data
|
428
|
+
|
429
|
+
@property
|
430
|
+
def prompt(self) -> str:
|
431
|
+
# All sequences in the group should have the same prompt.
|
432
|
+
# We use the prompt of an arbitrary sequence.
|
433
|
+
return next(iter(self.seqs_dict.values())).prompt
|
434
|
+
|
435
|
+
@property
|
436
|
+
def prompt_token_ids(self) -> List[int]:
|
437
|
+
# All sequences in the group should have the same prompt.
|
438
|
+
# We use the prompt of an arbitrary sequence.
|
439
|
+
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
|
440
|
+
|
441
|
+
@property
|
442
|
+
def lora_int_id(self) -> int:
|
443
|
+
return self.lora_request.lora_int_id if self.lora_request else 0
|
444
|
+
|
445
|
+
def get_last_latency(self, now: float) -> Optional[float]:
|
446
|
+
"""Sets the last token time for Request level timings."""
|
447
|
+
# If still in prefill phase, raise Error.
|
448
|
+
if self.is_prefill():
|
449
|
+
raise ValueError(
|
450
|
+
"seq_group.get_last_latency() should not be called "
|
451
|
+
"if the seq_group is in prefill phase.")
|
452
|
+
|
453
|
+
# Otherwise return token latency.
|
454
|
+
latency = now - self.metrics.last_token_time
|
455
|
+
self.metrics.last_token_time = now
|
456
|
+
return latency
|
457
|
+
|
458
|
+
def maybe_set_first_token_time(self, time: float) -> None:
|
459
|
+
"""Sets the first token time for Request level timings."""
|
460
|
+
# Note: in a case where a sequence_group is swapped and
|
461
|
+
# recomputed, the time between iterations is counted
|
462
|
+
# in TPOT, rather than recalculating TTFT (since from the )
|
463
|
+
# POV of the user, there is simply a long generation delay.
|
464
|
+
if (self.metrics.first_token_time is None
|
465
|
+
and self.get_seqs()[0].get_output_len() == 1):
|
466
|
+
self.metrics.first_token_time = time
|
467
|
+
|
468
|
+
def maybe_set_first_scheduled_time(self, time: float) -> None:
|
469
|
+
"""Sets the first scheduled time and time in queue for Request
|
470
|
+
level timings."""
|
471
|
+
if self.metrics.first_scheduled_time is None:
|
472
|
+
self.metrics.first_scheduled_time = time
|
473
|
+
self.metrics.time_in_queue = time - self.metrics.arrival_time
|
474
|
+
|
475
|
+
def set_finished_time(self, time: Optional[float]) -> None:
|
476
|
+
"""Sets the finished time for Request level timings."""
|
477
|
+
self.metrics.finished_time = time
|
478
|
+
|
479
|
+
def get_max_num_running_seqs(self) -> int:
|
480
|
+
"""The maximum number of sequences running in parallel in the remaining
|
481
|
+
lifetime of the request."""
|
482
|
+
if self.sampling_params.use_beam_search:
|
483
|
+
# For beam search, maximally there will always be `best_of` beam
|
484
|
+
# candidates running in the future.
|
485
|
+
return self.sampling_params.best_of
|
486
|
+
else:
|
487
|
+
if self.sampling_params.best_of > self.num_seqs():
|
488
|
+
# At prompt stage, the sequence group is not yet filled up
|
489
|
+
# and only have one sequence running. However, in the
|
490
|
+
# generation stage, we will have `best_of` sequences running.
|
491
|
+
return self.sampling_params.best_of
|
492
|
+
# At sampling stages, return the number of actual sequences
|
493
|
+
# that are not finished yet.
|
494
|
+
return self.num_unfinished_seqs()
|
495
|
+
|
496
|
+
def get_seqs(
|
497
|
+
self,
|
498
|
+
status: Optional[SequenceStatus] = None,
|
499
|
+
) -> List[Sequence]:
|
500
|
+
return list(self.seqs_dict.values()) if status is None else [
|
501
|
+
seq for seq in self.seqs_dict.values() if seq.status == status
|
502
|
+
]
|
503
|
+
|
504
|
+
def get_unfinished_seqs(self) -> List[Sequence]:
|
505
|
+
return [
|
506
|
+
seq for seq in self.seqs_dict.values() if not seq.is_finished()
|
507
|
+
]
|
508
|
+
|
509
|
+
def get_finished_seqs(self) -> List[Sequence]:
|
510
|
+
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
511
|
+
|
512
|
+
def update_num_computed_tokens(self, num_new_computed_tokens: int):
|
513
|
+
"""Update number of tokens computed so far."""
|
514
|
+
for seq in self.seqs_dict.values():
|
515
|
+
if not seq.is_finished():
|
516
|
+
seq.data.update_num_computed_tokens(num_new_computed_tokens)
|
517
|
+
|
518
|
+
def get_num_uncomputed_tokens(self) -> int:
    """Return the number of tokens not yet computed, summed over all
    unfinished sequences of the group (0 if there are none)."""
    return sum(seq.data.get_num_uncomputed_tokens()
               for seq in self.get_seqs()
               if not seq.is_finished())
|
524
|
+
|
525
|
+
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
    """Count the sequences in the group, optionally only those with
    ``status``.

    Without a filter the dict size is returned directly, which avoids
    building the list that ``get_seqs`` would create.
    """
    if status is not None:
        return len(self.get_seqs(status))
    return len(self.seqs_dict)
|
532
|
+
|
533
|
+
def num_unfinished_seqs(self) -> int:
    """Return how many sequences in this group are still unfinished."""
    unfinished = self.get_unfinished_seqs()
    return len(unfinished)
|
535
|
+
|
536
|
+
def num_finished_seqs(self) -> int:
    """Return how many sequences in this group have finished."""
    finished = self.get_finished_seqs()
    return len(finished)
|
538
|
+
|
539
|
+
def find(self, seq_id: int) -> Sequence:
    """Return the sequence of this group with the given id.

    Raises:
        ValueError: If no sequence with ``seq_id`` is in the group.
    """
    if seq_id in self.seqs_dict:
        return self.seqs_dict[seq_id]
    raise ValueError(f"Sequence {seq_id} not found.")
|
543
|
+
|
544
|
+
def add(self, seq: Sequence) -> None:
    """Register ``seq`` in this group under its ``seq_id``.

    Raises:
        ValueError: If a sequence with the same id is already present.
    """
    key = seq.seq_id
    if key in self.seqs_dict:
        raise ValueError(f"Sequence {key} already exists.")
    self.seqs_dict[key] = seq
|
548
|
+
|
549
|
+
def remove(self, seq_id: int) -> None:
    """Drop the sequence with ``seq_id`` from this group.

    Raises:
        ValueError: If no sequence with that id is present.
    """
    seqs = self.seqs_dict
    if seq_id not in seqs:
        raise ValueError(f"Sequence {seq_id} not found.")
    seqs.pop(seq_id)
|
553
|
+
|
554
|
+
def is_finished(self) -> bool:
    """Return True when no sequence of the group is still unfinished.

    A group with no sequences counts as finished.
    """
    return not any(not seq.is_finished() for seq in self.get_seqs())
|
556
|
+
|
557
|
+
def is_prefill(self) -> bool:
    """Return whether the group is in the prefill stage.

    Only the first sequence is consulted, because every sequence in a group
    is expected to be in the same stage. An empty group raises IndexError.
    """
    seqs = self.get_seqs()
    leader = seqs[0]
    return leader.is_prefill()
|
560
|
+
|
561
|
+
def __repr__(self) -> str:
    """Debug representation showing request id, sampling params and the
    number of sequences in the group."""
    fields = (
        f"request_id={self.request_id}",
        f"sampling_params={self.sampling_params}",
        f"num_seqs={len(self.seqs_dict)}",
    )
    return "SequenceGroup(" + ", ".join(fields) + ")"
|
565
|
+
|
566
|
+
|
567
|
+
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at the prompt stage.
        seq_data: Mapping from sequence id to its sequence data.
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: Mapping from sequence id to its list of physical block
            numbers.
        do_sample: Whether sampling is required. It is not required when,
            e.g., a prefill is chunked and the current iteration only
            computes query tokens of the prompt.
        token_chunk_size: The number of tokens to be processed (per
            sequence). When None, it defaults to the length of the first
            sequence's data for prompts and to 1 otherwise.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group. A new
            `SequenceGroupState()` is created when None.
        multi_modal_data: Multi modal data.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = state if state is not None else SequenceGroupState()
        self.do_sample = do_sample
        if token_chunk_size is None:
            token_chunk_size = self._default_chunk_size(is_prompt, seq_data)
        self._token_chunk_size = token_chunk_size

    @staticmethod
    def _default_chunk_size(is_prompt: bool,
                            seq_data: Dict[int, SequenceData]) -> int:
        """Chunk size used when none is given.

        For a prompt this is the length of the first sequence's data
        (IndexError if ``seq_data`` is empty); otherwise a single token.
        """
        if not is_prompt:
            return 1
        first_data = list(seq_data.values())[0]
        return first_data.get_len()

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    @property
    def token_chunk_size(self) -> Optional[int]:
        """Return the number of tokens to be processed (chunk size)."""
        return self._token_chunk_size
|
629
|
+
|
630
|
+
|
631
|
+
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when parent id, output token and logprobs all match.

        For objects that are not `SequenceOutput`, return ``NotImplemented``
        instead of raising, so Python falls back to the reflected comparison
        (and finally identity). This keeps ``x == None``, ``x != 0`` and
        ``x in mixed_list`` working.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
|
664
|
+
|
665
|
+
|
666
|
+
class SequenceGroupOutput:
    """The model output associated with a sequence group.

    Args:
        samples: The sampled outputs, one `SequenceOutput` per candidate.
        prompt_logprobs: Prompt logprob for each prompt query token, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when both the samples and the prompt logprobs match.

        For objects that are not `SequenceGroupOutput`, return
        ``NotImplemented`` instead of raising, so comparisons with unrelated
        objects (``None``, ints, mixed lists) evaluate instead of crashing.
        """
        if not isinstance(other, SequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)
|
687
|
+
|
688
|
+
|
689
|
+
@dataclass
class SamplerOutput:
    """Sampler result: one `SequenceGroupOutput` per sequence group, each
    listing `SequenceOutput` candidates for the next token.

    Indexing, item assignment and ``len`` delegate to ``outputs``, so the
    object can be used like a list. It also carries optional device tensors
    produced by the sampler.
    """

    outputs: List[SequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        """Return the group output at position ``idx``."""
        group_outputs = self.outputs
        return group_outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the group output at position ``idx``."""
        group_outputs = self.outputs
        group_outputs[idx] = value

    def __len__(self):
        """Number of sequence-group outputs."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal to another instance of this class with equal ``outputs``;
        the tensor fields and metrics are ignored."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show tensor shapes instead of values to reduce noise.

        ``logprobs`` is not included in the representation.
        """

        def shape_or_none(tensor):
            return "None" if tensor is None else tensor.shape

        probs_repr = shape_or_none(self.sampled_token_probs)
        ids_repr = shape_or_none(self.sampled_token_ids)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_repr}, "
            f"sampled_token_ids={ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|
737
|
+
|
738
|
+
|
739
|
+
@dataclass
class ExecuteModelRequest:
    """The model execution request."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. Dict of CPU -> GPU block number.
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    # Blocks to swap out. Dict of GPU -> CPU block number.
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    # Blocks to copy. Source to a list of dest blocks.
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The block mappings are copied so that mutating the clone does not
        affect this request. This includes the per-source destination lists
        in ``blocks_to_copy``. The given ``seq_group_metadata_list`` is used
        as-is (not copied).
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            # dict.copy() is shallow: copy the inner destination lists too,
            # otherwise the clone would share them with this request.
            blocks_to_copy={
                src: list(dsts)
                for src, dsts in self.blocks_to_copy.items()
            },
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
        )
|