vllm_npu-0.4.2-py3-none-any.whl
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
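
As a quick sanity check that the listing above matches the packaged artifact, the wheel's own `RECORD` file (listed under `vllm_npu-0.4.2.dist-info/`) can be read with the standard library. This is a minimal sketch; the local wheel filename is an assumption, and each `RECORD` row has the form `path,sha256=<digest>,size`, so the first column should reproduce the paths listed above.

```python
import zipfile

WHEEL = "vllm_npu-0.4.2-py3-none-any.whl"  # assumed local filename

with zipfile.ZipFile(WHEEL) as wf:
    record = wf.read("vllm_npu-0.4.2.dist-info/RECORD").decode("utf-8")
    for line in record.splitlines()[:10]:
        # Each RECORD row is "path,sha256=...,size"; print the path column.
        print(line.split(",")[0])
```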
vllm/engine/output_processor/single_step.py
@@ -0,0 +1,284 @@
```python
from typing import Dict, List, Tuple, Union

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
                           SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

logger = init_logger(__name__)


class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max len
    or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes configurations
    such as speculative decoding or multi-step decoding. This enables beam
    search sampling, which requires forking/finishing/freeing sequences in a way
    that is currently difficult to schedule multiple steps ahead of time.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        stop_checker: StopChecker,
    ):
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0])

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        assert len(outputs) == 1, ("Single step should only has 1 output.")
        output = outputs[0]
        prompt_logprobs = output.prompt_logprobs
        if (prompt_logprobs is not None
                and seq_group.sampling_params.detokenize and self.detokenizer):
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            if not seq_group.prompt_logprobs:
                # The first prompt token's logprob is None because it doesn't
                # have tokens that are precedent.
                seq_group.prompt_logprobs = [None]
            seq_group.prompt_logprobs.extend(prompt_logprobs)

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict: Dict[int, List[SequenceOutput]] = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id: int = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize and self.detokenizer:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
                                                  seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id)
        if early_stopping is False:
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id))
        return current_worst_score >= highest_attainable_score
```
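
The forking rule inside `_process_sequence_group_outputs` is easy to miss in the bulk of the code: when one parent sequence receives k samples in a step, k-1 children are forked and the parent itself absorbs the final sample, so the already-generated prefix is reused rather than copied. Below is a minimal, self-contained sketch of just that bookkeeping; `ToySequence` and `apply_samples` are stand-ins invented for illustration, not the real `Sequence`/`SequenceGroup` API.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySequence:
    seq_id: int
    tokens: List[int] = field(default_factory=list)

    def fork(self, new_id: int) -> "ToySequence":
        # Copy of the sequence so far, under a fresh id.
        return ToySequence(new_id, list(self.tokens))


def apply_samples(parent: ToySequence, samples: List[int],
                  next_id: int) -> List[ToySequence]:
    """Mirror the child/parent bookkeeping: fork for all but the last
    sample, then append the last sample to the parent in place."""
    children = []
    for tok in samples[:-1]:
        child = parent.fork(next_id)
        next_id += 1
        child.tokens.append(tok)
        children.append(child)
    parent.tokens.append(samples[-1])
    children.append(parent)
    return children


if __name__ == "__main__":
    parent = ToySequence(seq_id=0, tokens=[1, 2, 3])
    out = apply_samples(parent, samples=[7, 8, 9], next_id=1)
    print([(s.seq_id, s.tokens) for s in out])
    # [(1, [1, 2, 3, 7]), (2, [1, 2, 3, 8]), (0, [1, 2, 3, 9])]
```

Reusing the parent for the last sample is what the in-code comment refers to as reducing redundant memory copies: for plain non-beam sampling with a single sample per parent, the loop collapses to one in-place `append_token_id`.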
vllm/engine/output_processor/stop_checker.py
@@ -0,0 +1,101 @@
```python
from typing import Callable, Optional

from transformers import PreTrainedTokenizer

from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus


class StopChecker:
    """LLMEngine helper class which separates out the logic involving stop
    checking. This checks things such as: whether the eos token was emitted,
    whether the max_tokens has been consumed, whether a stop string has been
    emitted, or if we have exceeded the max model len.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence],
                                                 PreTrainedTokenizer]):
        self.max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> None:
        """Stop the finished sequences.

        new_char_count is the number of chars added to the
        sequence's output text for the newly generated token
        """

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id()
        if last_token_id in sampling_params.stop_token_ids:
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop_str = self._check_stop_strings(seq, new_char_count,
                                            sampling_params)
        if stop_str is not None:
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Returns the stop string if matched or else None.
        """
        if not new_char_count:
            return None

        for stop_str in sampling_params.stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text.
            stop_index = seq.output_text.find(
                stop_str, -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(seq.output_text):
                    # No truncation required.
                    return stop_str

            # Truncate the output text to either the beginning
            # or end of the stop string.
            seq.output_text = seq.output_text[:stop_index]
            return stop_str
        return None
```
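
The subtle part of `_check_stop_strings` is the search window: at most `new_char_count` characters were appended this step, so only the tail of `output_text`, padded by the stop-string length, can contain a new match. Here is a standalone sketch of that window logic, using plain strings instead of `Sequence` objects and returning the truncated text rather than mutating anything; the helper name is invented for illustration.

```python
from typing import Optional


def find_stop(output_text: str, new_char_count: int, stop: str,
              include_stop: bool) -> Optional[str]:
    """Return output_text truncated at a newly matched stop string, or None."""
    if not new_char_count:
        return None
    # Only the tail that could contain a new match needs to be searched.
    start = -new_char_count - len(stop)
    idx = output_text.find(stop, start)
    if idx == -1:
        return None
    if include_stop:
        idx += len(stop)
    return output_text[:idx]


if __name__ == "__main__":
    # "END" is completed by the newly generated "D" (new_char_count=1).
    print(find_stop("hello world END", 1, "END", include_stop=False))
    # -> "hello world "
    print(find_stop("hello world END", 1, "END", include_stop=True))
    # -> "hello world END"
```

A negative `start` for `str.find` is interpreted like a slice index, so each step rescans only a constant-size tail instead of the whole accumulated output.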
vllm/engine/output_processor/util.py
@@ -0,0 +1,19 @@
```python
from typing import List

from vllm.sequence import SamplerOutput, SequenceGroupOutput


def create_output_by_sequence_group(
        sampler_outputs: List[SamplerOutput],
        num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
    """Helper method which transforms a 2d list organized by
    [step][sequence group] into [sequence group][step].
    """
    output_by_sequence_group: List[List[SamplerOutput]] = [
        [] for _ in range(num_seq_groups)
    ]
    for step in sampler_outputs:
        for i, sequence_group_output in enumerate(step):
            output_by_sequence_group[i].append(sequence_group_output)

    return output_by_sequence_group
```
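
The helper is a plain transpose, so a quick check with strings standing in for the per-step outputs shows the reshaping; the inline loop below mirrors the function body and assumes 3 steps and 2 sequence groups.

```python
# [step][sequence group] layout: 3 steps, 2 groups.
steps = [["g0_s0", "g1_s0"],
         ["g0_s1", "g1_s1"],
         ["g0_s2", "g1_s2"]]

# Transpose into [sequence group][step], as create_output_by_sequence_group does.
by_group = [[] for _ in range(2)]
for step in steps:
    for i, out in enumerate(step):
        by_group[i].append(out)

print(by_group)
# [['g0_s0', 'g0_s1', 'g0_s2'], ['g1_s0', 'g1_s1', 'g1_s2']]
```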
File without changes
vllm/entrypoints/api_server.py
@@ -0,0 +1,119 @@
```python
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import argparse
import json
import ssl
from typing import AsyncGenerator

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5  # seconds.
app = FastAPI()
engine = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    prompt = final_output.prompt
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument("--ssl-ca-certs",
                        type=str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)"
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.API_SERVER)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)
```
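
For completeness, here is a minimal client sketch against this demo server, assuming it was started with the defaults shown above (localhost, port 8000). Extra JSON fields are unpacked into `SamplingParams`, so `max_tokens` and `temperature` below are simply passed through; in the streaming case the server separates JSON chunks with a NUL byte, as the `"\0"` delimiter in `stream_results` shows.

```python
import json

import requests

URL = "http://localhost:8000/generate"  # assumed host/port

# Non-streaming request: one JSON response with a "text" list.
resp = requests.post(URL, json={"prompt": "Hello, my name is",
                                "max_tokens": 32,
                                "temperature": 0.8})
print(resp.json()["text"])

# Streaming request: NUL-delimited JSON chunks, each with the text so far.
with requests.post(URL, json={"prompt": "Hello, my name is",
                              "stream": True,
                              "max_tokens": 32}, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk.decode("utf-8"))["text"])
```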