vllm_npu-0.4.2-py3-none-any.whl
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/engine/async_llm_engine.py
@@ -0,0 +1,737 @@
import asyncio
import time
from functools import partial
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
                    Optional, Set, Tuple, Type, Union)

from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S


class AsyncEngineDeadError(RuntimeError):
    pass


def _raise_exception_on_finish(
        task: asyncio.Task, error_callback: Callable[[Exception],
                                                      None]) -> None:
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")

    exception = None
    try:
        task.result()
        # NOTE: This will be thrown if task exits normally (which it should not)
        raise AsyncEngineDeadError(msg)
    except Exception as e:
        exception = e
        logger.error("Engine background task failed", exc_info=e)
        error_callback(exception)
        raise AsyncEngineDeadError(
            msg + " See stack trace above for the actual cause.") from e


class AsyncStream:
    """A stream of RequestOutputs for a request that can be
    iterated over asynchronously."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, Exception]) -> None:
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._queue.get()
        if isinstance(result, Exception):
            raise result
        return result


class RequestTracker:
    """Synchronous abstraction for tracking requests."""

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id

        self._request_streams[request_id].put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        self._request_streams[request_id].put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    async def step_async(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are ran asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if not scheduler_outputs.is_empty():
            # Execute the model.
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
            output = await self.model_executor.execute_model_async(
                execute_model_req)
        else:
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        return request_outputs

    async def encode_request_async(
        self,
        request_id: str,  # pylint: disable=unused-argument
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = await self.tokenizer.encode_async(
                request_id=request_id,
                prompt=prompt,
                lora_request=lora_request)
        return prompt_token_ids

    async def add_request_async(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()
        prompt_token_ids = await self.encode_request_async(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        return self.add_request(request_id,
                                prompt=prompt,
                                prompt_token_ids=prompt_token_ids,
                                sampling_params=sampling_params,
                                arrival_time=arrival_time,
                                lora_request=lora_request,
                                multi_modal_data=multi_modal_data)

    async def check_health_async(self) -> None:
        self.model_executor.check_health()


class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class is used to wrap the LLMEngine class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMEngine is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMEngine to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process as the
            model workers.
        log_requests: Whether to log the requests.
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for LLMEngine.
        *kwargs: Arguments for LLMEngine.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        self._request_tracker: RequestTracker

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        if engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            assert not engine_config.parallel_config.worker_use_ray, (
                "Ray is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.parallel_config.worker_use_ray:
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        else:
            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        # Create the async LLM engine.
        engine = cls(
            engine_config.parallel_config.worker_use_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )
        return engine

    @property
    def is_running(self) -> bool:
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        if self.engine_use_ray:
            return await self.engine.get_tokenizer.remote()  # type: ignore
        else:
            return self.engine.get_tokenizer()

    def start_background_loop(self) -> None:
        """Start the background loop."""
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_raise_exception_on_finish,
                    error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        if not self.engine_use_ray:
            engine_class = self._engine_class
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            if parallel_config.tensor_parallel_size == 1:
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            engine_class = ray.remote(num_gpus=num_gpus)(
                self._engine_class).remote
        return engine_class(*args, **kwargs)

    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""

        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(  # type: ignore
                        **new_request)
                else:
                    await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()  # type: ignore
        else:
            request_outputs = await self.engine.step_async()

        # Put the outputs into the corresponding streams.
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)

        return len(request_outputs) > 0

    async def _engine_abort(self, request_ids: Iterable[str]):
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)  # type: ignore
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        has_requests_in_progress = False
        while True:
            if not has_requests_in_progress:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                has_requests_in_progress = await asyncio.wait_for(
                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> AsyncStream:
        if self.log_requests:
            shortened_prompt = prompt
            shortened_token_ids = prompt_token_ids
            if self.max_log_len is not None:
                if shortened_prompt is not None:
                    shortened_prompt = shortened_prompt[:self.max_log_len]
                if shortened_token_ids is not None:
                    shortened_token_ids = shortened_token_ids[:self.
                                                              max_log_len]
            logger.info(
                "Received request %s: prompt: %r, "
                "sampling_params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, shortened_prompt,
                sampling_params, shortened_token_ids, lora_request)

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if arrival_time is None:
            arrival_time = time.time()

        if self.engine_use_ray:
            prompt_token_ids = await (
                self.engine.encode_request_async.remote(  # type: ignore
                    request_id=request_id,
                    prompt=prompt,
                    prompt_token_ids=prompt_token_ids,
                    lora_request=lora_request))
        else:
            prompt_token_ids = await self.engine.encode_request_async(
                request_id=request_id,
                prompt=prompt,
                prompt_token_ids=prompt_token_ids,
                lora_request=lora_request)

        stream = self._request_tracker.add_request(
            request_id,
            prompt=prompt,
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
            arrival_time=arrival_time,
            lora_request=lora_request,
            multi_modal_data=multi_modal_data,
        )

        return stream

    async def generate(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data per request.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        # Preprocess the request.
        arrival_time = time.time()

        try:
            stream = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time,
                lora_request=lora_request,
                multi_modal_data=multi_modal_data,
            )

            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or coroutine is cancelled, abort the
            # request.
            self._abort(request_id)
            raise e

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()  # type: ignore
        else:
            return self.engine.get_model_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_decoding_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_decoding_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if self.engine_use_ray:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        else:
            await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)
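
A minimal usage sketch of the AsyncLLMEngine API shown above. This is illustrative only and assumes the wheel installs under the `vllm` top-level package (as the file list suggests) and that a supported device backend is available; the model name is a placeholder.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams


async def main() -> None:
    # Placeholder model name; substitute any model supported by your backend.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    params = SamplingParams(temperature=0.0, max_tokens=32)
    final_output = None
    # generate() is an async iterator; each item is a cumulative RequestOutput.
    # The background engine loop is started on the first generate() call.
    async for request_output in engine.generate("What is LLM?", params,
                                                request_id="0"):
        final_output = request_output

    if final_output is not None:
        print(final_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())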