vllm-npu 0.4.2__py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/engine/output_processor/single_step.py
@@ -0,0 +1,284 @@
+from typing import Dict, List, Tuple, Union
+
+from vllm.config import SchedulerConfig
+from vllm.core.scheduler import Scheduler
+from vllm.engine.output_processor.interfaces import (
+    SequenceGroupOutputProcessor)
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.logger import init_logger
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
+                           SequenceOutput, SequenceStatus)
+from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
+
+logger = init_logger(__name__)
+
+
+class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
+    """SequenceGroupOutputProcessor which handles "output processing" logic,
+    which happens after the model returns generated token ids and before
+    scheduling of the next batch. Output processing logic includes
+    detokenization, and determining if a sequence is finished (e.g. via max len
+    or eos token).
+
+    The SingleStepOutputProcessor is specialized to the case where the model
+    emits at most a single token per invocation, which precludes configurations
+    such as speculative decoding or multi-step decoding. This enables beam
+    search sampling, which requires forking/finishing/freeing sequences in a way
+    that is currently difficult to schedule multiple steps ahead of time.
+    """
+
+    def __init__(
+        self,
+        scheduler_config: SchedulerConfig,
+        detokenizer: Detokenizer,
+        scheduler: Scheduler,
+        seq_counter: Counter,
+        stop_checker: StopChecker,
+    ):
+        self.scheduler_config = scheduler_config
+        self.detokenizer = detokenizer
+        self.scheduler = scheduler
+        self.seq_counter = seq_counter
+        self.stop_checker = stop_checker
+
+    def process_outputs(self, sequence_group: SequenceGroup,
+                        outputs: List[SequenceGroupOutput]) -> None:
+        """Append all new tokens to sequences in the sequence group. Fork any
+        surviving beam candidates; free any unsurviving ones.
+
+        Invokes detokenizer to detokenize new tokens, and also marks sequences
+        as finished if they meet stop conditions.
+        """
+        assert (len(outputs) == 1
+                ), f"{type(self)} does not support multiple outputs per step"
+        return self._process_sequence_group_outputs(sequence_group, outputs[0])
+
+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        assert len(outputs) == 1, ("Single step should only has 1 output.")
+        output = outputs[0]
+        prompt_logprobs = output.prompt_logprobs
+        if (prompt_logprobs is not None
+                and seq_group.sampling_params.detokenize and self.detokenizer):
+            self.detokenizer.decode_prompt_logprobs_inplace(
+                seq_group, prompt_logprobs)
+            if not seq_group.prompt_logprobs:
+                # The first prompt token's logprob is None because it doesn't
+                # have tokens that are precedent.
+                seq_group.prompt_logprobs = [None]
+            seq_group.prompt_logprobs.extend(prompt_logprobs)
+
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutput) -> None:
+        # Process samples
+        samples = outputs.samples
+        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        existing_finished_seqs = seq_group.get_finished_seqs()
+        parent_child_dict: Dict[int, List[SequenceOutput]] = {
+            parent_seq.seq_id: []
+            for parent_seq in parent_seqs
+        }
+        for sample in samples:
+            parent_child_dict[sample.parent_seq_id].append(sample)
+        # List of (child, parent)
+        child_seqs: List[Tuple[Sequence, Sequence]] = []
+
+        # Process the child samples for each parent sequence
+        for parent in parent_seqs:
+            child_samples: List[SequenceOutput] = parent_child_dict[
+                parent.seq_id]
+            if len(child_samples) == 0:
+                # This parent sequence has no children samples. Remove
+                # the parent sequence from the sequence group since it will
+                # not be used in the future iterations.
+                parent.status = SequenceStatus.FINISHED_ABORTED
+                seq_group.remove(parent.seq_id)
+                self.scheduler.free_seq(parent)
+                continue
+            # Fork the parent sequence if there are multiple child samples.
+            for child_sample in child_samples[:-1]:
+                new_child_seq_id: int = next(self.seq_counter)
+                child = parent.fork(new_child_seq_id)
+                child.append_token_id(child_sample.output_token,
+                                      child_sample.logprobs)
+                child_seqs.append((child, parent))
+            # Continue the parent sequence for the last child sample.
+            # We reuse the parent sequence here to reduce redundant memory
+            # copies, especially when using non-beam search sampling methods.
+            last_child_sample = child_samples[-1]
+            parent.append_token_id(last_child_sample.output_token,
+                                   last_child_sample.logprobs)
+            child_seqs.append((parent, parent))
+
+        for seq, _ in child_seqs:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
+                new_char_count = self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
+            else:
+                new_char_count = 0
+            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
+                                                  seq_group.sampling_params)
+
+        # Non-beam search case
+        if not seq_group.sampling_params.use_beam_search:
+            # For newly created child sequences, add them to the sequence group
+            # and fork them in block manager if they are not finished.
+            for seq, parent in child_seqs:
+                if seq is not parent:
+                    seq_group.add(seq)
+                    if not seq.is_finished():
+                        self.scheduler.fork_seq(parent, seq)
+
+            # Free the finished and selected parent sequences' memory in block
+            # manager. Keep them in the sequence group as candidate output.
+            # NOTE: we need to fork the new sequences before freeing the
+            # old sequences.
+            for seq, parent in child_seqs:
+                if seq is parent and seq.is_finished():
+                    self.scheduler.free_seq(seq)
+            return
+
+        # Beam search case
+        # Select the child sequences to keep in the sequence group.
+        selected_child_seqs = []
+        unselected_child_seqs = []
+        beam_width = seq_group.sampling_params.best_of
+        length_penalty = seq_group.sampling_params.length_penalty
+
+        # Select the newly finished sequences with the highest scores
+        # to replace existing finished sequences.
+        # Tuple of (seq, parent, is_new)
+        existing_finished_seqs = [(seq, None, False)
+                                  for seq in existing_finished_seqs]
+        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
+                             if seq.is_finished()]
+        all_finished_seqs = existing_finished_seqs + new_finished_seqs
+        # Sort the finished sequences by their scores.
+        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+                               reverse=True)
+        for seq, parent, is_new in all_finished_seqs[:beam_width]:
+            if is_new:
+                # A newly generated child sequence finishes and has a high
+                # score, so we will add it into the sequence group.
+                selected_child_seqs.append((seq, parent))
+        for seq, parent, is_new in all_finished_seqs[beam_width:]:
+            if is_new:
+                # A newly generated child sequence finishes but has a low
+                # score, so we will not add it into the sequence group.
+                # Additionally, if this sequence is a continuation of a
+                # parent sequence, we will need remove the parent sequence
+                # from the sequence group.
+                unselected_child_seqs.append((seq, parent))
+            else:
+                # An existing finished sequence has a low score, so we will
+                # remove it from the sequence group.
+                seq_group.remove(seq.seq_id)
+
+        # select the top beam_width sequences from the running
+        # sequences for the next iteration to continue the beam
+        # search.
+        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
+                              if not seq.is_finished()]
+        # Sort the running sequences by their scores.
+        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+                                reverse=True)
+
+        # Check if we can stop the beam search.
+        if len(running_child_seqs) == 0:
+            # No running sequences, stop the beam search.
+            stop_beam_search = True
+        elif len(all_finished_seqs) < beam_width:
+            # Not enough finished sequences, continue the beam search.
+            stop_beam_search = False
+        else:
+            # Check the early stopping criteria
+            best_running_seq = running_child_seqs[0][0]
+            current_worst_seq = all_finished_seqs[beam_width - 1][0]
+            stop_beam_search = self._check_beam_search_early_stopping(
+                seq_group.sampling_params.early_stopping,
+                seq_group.sampling_params, best_running_seq, current_worst_seq)
+
+        if stop_beam_search:
+            # Stop the beam search and remove all the running sequences from
+            # the sequence group.
+            unselected_child_seqs.extend(running_child_seqs)
+        else:
+            # Continue the beam search and select the top beam_width sequences
+            # to continue the beam search.
+            selected_child_seqs.extend(running_child_seqs[:beam_width])
+            # The remaining running sequences will not be used in the next
+            # iteration. Again, if these sequences are continuations of
+            # parent sequences, we will need to remove the parent sequences
+            # from the sequence group.
+            unselected_child_seqs.extend(running_child_seqs[beam_width:])
+
+        # For newly created child sequences, add them to the sequence group
+        # and fork them in block manager if they are not finished.
+        for seq, parent in selected_child_seqs:
+            if seq is not parent:
+                seq_group.add(seq)
+                if not seq.is_finished():
+                    self.scheduler.fork_seq(parent, seq)
+
+        # Free the finished and selected parent sequences' memory in block
+        # manager. Keep them in the sequence group as candidate output.
+        for seq, parent in selected_child_seqs:
+            if seq is parent and seq.is_finished():
+                self.scheduler.free_seq(seq)
+
+        # Remove the unselected parent sequences from the sequence group and
+        # free their memory in block manager.
+        for seq, parent in unselected_child_seqs:
+            if seq is parent:
+                # Remove the parent sequence if it is not selected for next
+                # iteration
+                seq_group.remove(seq.seq_id)
+                self.scheduler.free_seq(seq)
+
+    def _check_beam_search_early_stopping(
+        self,
+        early_stopping: Union[bool, str],
+        sampling_params: SamplingParams,
+        best_running_seq: Sequence,
+        current_worst_seq: Sequence,
+    ) -> bool:
+        assert sampling_params.use_beam_search
+        length_penalty = sampling_params.length_penalty
+        if early_stopping is True:
+            return True
+
+        current_worst_score = current_worst_seq.get_beam_search_score(
+            length_penalty=length_penalty,
+            eos_token_id=current_worst_seq.eos_token_id)
+        if early_stopping is False:
+            highest_attainable_score = best_running_seq.get_beam_search_score(
+                length_penalty=length_penalty,
+                eos_token_id=best_running_seq.eos_token_id)
+        else:
+            assert early_stopping == "never"
+            if length_penalty > 0.0:
+                # If length_penalty > 0.0, beam search will prefer longer
+                # sequences. The highest attainable score calculation is
+                # based on the longest possible sequence length in this case.
+                max_possible_length = max(
+                    best_running_seq.get_prompt_len() +
+                    sampling_params.max_tokens,
+                    self.scheduler_config.max_model_len)
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=best_running_seq.eos_token_id,
+                        seq_len=max_possible_length))
+            else:
+                # Otherwise, beam search will prefer shorter sequences. The
+                # highest attainable score calculation is based on the current
+                # sequence length.
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=best_running_seq.eos_token_id))
+        return current_worst_score >= highest_attainable_score
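The beam-search bookkeeping above ranks candidates with Sequence.get_beam_search_score. A minimal standalone sketch of that length-penalised scoring rule, assuming the usual cumulative_logprob / seq_len ** length_penalty form (the helper below is illustrative and not taken from the wheel), shows why the early-stopping branch treats a larger length_penalty as favouring longer candidates:

def beam_search_score(cumulative_logprob: float, seq_len: int,
                      length_penalty: float) -> float:
    # Log-probabilities are negative; dividing by seq_len ** length_penalty
    # normalises for length, and raising length_penalty shifts the ranking
    # towards longer candidates.
    return cumulative_logprob / (seq_len**length_penalty)

# Two candidates with the same average token logprob (-1.0 per token):
# they tie at length_penalty=1.0, and the longer one wins at 1.5.
assert beam_search_score(-4.0, 4, 1.0) == beam_search_score(-8.0, 8, 1.0)
assert beam_search_score(-8.0, 8, 1.5) > beam_search_score(-4.0, 4, 1.5)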
vllm/engine/output_processor/stop_checker.py
@@ -0,0 +1,101 @@
+from typing import Callable, Optional
+
+from transformers import PreTrainedTokenizer
+
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import Sequence, SequenceStatus
+
+
+class StopChecker:
+    """LLMEngine helper class which separates out the logic involving stop
+    checking. This checks things such as: whether the eos token was emitted,
+    whether the max_tokens has been consumed, whether a stop string has been
+    emitted, or if we have exceeded the max model len.
+    """
+
+    def __init__(self, max_model_len: int,
+                 get_tokenizer_for_seq: Callable[[Sequence],
+                                                 PreTrainedTokenizer]):
+        self.max_model_len = max_model_len
+        self.get_tokenizer_for_seq = get_tokenizer_for_seq
+
+    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> None:
+        """Stop the finished sequences.
+
+        new_char_count is the number of chars added to the
+        sequence's output text for the newly generated token
+        """
+
+        # Check if the minimum number of tokens has been generated yet;
+        # skip the stop string/token checks if not
+        if seq.get_output_len() < sampling_params.min_tokens:
+            return
+
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
+
+        # Check if a stop token was encountered.
+        # This assumes a single token produced per step.
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                # Remove last token
+                seq.output_text = seq.output_text[:-new_char_count]
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
+            return
+
+        # Check if any stop strings are matched.
+        stop_str = self._check_stop_strings(seq, new_char_count,
+                                            sampling_params)
+        if stop_str is not None:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = stop_str
+            return
+
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+    @staticmethod
+    def _check_stop_strings(seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> Optional[str]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+
+        Returns the stop string if matched or else None.
+        """
+        if not new_char_count:
+            return None
+
+        for stop_str in sampling_params.stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = seq.output_text.find(
+                stop_str, -new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if sampling_params.include_stop_str_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(seq.output_text):
+                    # No truncation required.
+                    return stop_str
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            seq.output_text = seq.output_text[:stop_index]
+            return stop_str
+        return None
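_check_stop_strings above only searches the tail of output_text that a new match could occupy: the freshly decoded characters plus one stop-string length of overlap. A small plain-string sketch of the same windowed search and truncation (a hypothetical helper, no vLLM objects involved):

from typing import Optional, Tuple

def find_and_truncate(output_text: str, new_char_count: int,
                      stop: str) -> Tuple[str, Optional[str]]:
    # Start the search just early enough to catch a stop string that
    # straddles previously decoded text and the newly appended characters.
    stop_index = output_text.find(stop, -new_char_count - len(stop))
    if stop_index == -1:
        return output_text, None
    # Truncate to the beginning of the stop string
    # (the include_stop_str_in_output=False behaviour).
    return output_text[:stop_index], stop

# Appending the final "!" completes the stop string "STOP!".
assert find_and_truncate("Hello STOP!", 1, "STOP!") == ("Hello ", "STOP!")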
vllm/engine/output_processor/util.py
@@ -0,0 +1,19 @@
+from typing import List
+
+from vllm.sequence import SamplerOutput, SequenceGroupOutput
+
+
+def create_output_by_sequence_group(
+        sampler_outputs: List[SamplerOutput],
+        num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
+    """Helper method which transforms a 2d list organized by
+    [step][sequence group] into [sequence group][step].
+    """
+    output_by_sequence_group: List[List[SamplerOutput]] = [
+        [] for _ in range(num_seq_groups)
+    ]
+    for step in sampler_outputs:
+        for i, sequence_group_output in enumerate(step):
+            output_by_sequence_group[i].append(sequence_group_output)
+
+    return output_by_sequence_group
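create_output_by_sequence_group is just a transpose of the [step][sequence group] nesting. A quick illustration with plain strings standing in for the per-group outputs (hypothetical values, not real SequenceGroupOutput objects):

# Two scheduler steps, each carrying one output per sequence group.
sampler_outputs = [["g0_step0", "g1_step0"],
                   ["g0_step1", "g1_step1"]]
by_group = [[] for _ in range(2)]
for step in sampler_outputs:
    for i, seq_group_output in enumerate(step):
        by_group[i].append(seq_group_output)
# Now indexed as [sequence group][step].
assert by_group == [["g0_step0", "g0_step1"], ["g1_step0", "g1_step1"]]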
vllm/entrypoints/__init__.py: file without changes
vllm/entrypoints/api_server.py
@@ -0,0 +1,119 @@
+"""
+NOTE: This API server is used only for demonstrating usage of AsyncEngine
+and simple performance benchmarks. It is not intended for production use.
+For production use, we recommend using our OpenAI compatible server.
+We are also not going to accept PRs modifying this file, please
+change `vllm/entrypoints/openai/api_server.py` instead.
+"""
+
+import argparse
+import json
+import ssl
+from typing import AsyncGenerator
+
+import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import random_uuid
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
+app = FastAPI()
+engine = None
+
+
+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+    """Generate completion for the request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - stream: whether to stream the results or not.
+    - other fields: the sampling parameters (See `SamplingParams` for details).
+    """
+    request_dict = await request.json()
+    prompt = request_dict.pop("prompt")
+    stream = request_dict.pop("stream", False)
+    sampling_params = SamplingParams(**request_dict)
+    request_id = random_uuid()
+
+    assert engine is not None
+    results_generator = engine.generate(prompt, sampling_params, request_id)
+
+    # Streaming case
+    async def stream_results() -> AsyncGenerator[bytes, None]:
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            text_outputs = [
+                prompt + output.text for output in request_output.outputs
+            ]
+            ret = {"text": text_outputs}
+            yield (json.dumps(ret) + "\0").encode("utf-8")
+
+    if stream:
+        return StreamingResponse(stream_results())
+
+    # Non-streaming case
+    final_output = None
+    async for request_output in results_generator:
+        if await request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await engine.abort(request_id)
+            return Response(status_code=499)
+        final_output = request_output
+
+    assert final_output is not None
+    prompt = final_output.prompt
+    text_outputs = [prompt + output.text for output in final_output.outputs]
+    ret = {"text": text_outputs}
+    return JSONResponse(ret)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default=None)
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--ssl-keyfile", type=str, default=None)
+    parser.add_argument("--ssl-certfile", type=str, default=None)
+    parser.add_argument("--ssl-ca-certs",
+                        type=str,
+                        default=None,
+                        help="The CA certificates file")
+    parser.add_argument(
+        "--ssl-cert-reqs",
+        type=int,
+        default=int(ssl.CERT_NONE),
+        help="Whether client certificate is required (see stdlib ssl module's)"
+    )
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument("--log-level", type=str, default="debug")
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(
+        engine_args, usage_context=UsageContext.API_SERVER)
+
+    app.root_path = args.root_path
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level=args.log_level,
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                ssl_keyfile=args.ssl_keyfile,
+                ssl_certfile=args.ssl_certfile,
+                ssl_ca_certs=args.ssl_ca_certs,
+                ssl_cert_reqs=args.ssl_cert_reqs)
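A minimal client sketch for the demo server above, assuming it is already running locally on the default port; any JSON fields other than prompt and stream are forwarded to SamplingParams (the third-party requests package is used here purely for brevity):

import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Hello, my name is",
        "stream": False,
        # Remaining fields become SamplingParams keyword arguments.
        "max_tokens": 16,
        "temperature": 0.8,
    })
resp.raise_for_status()
print(resp.json()["text"])  # list of prompt + completion strings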