tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/mock/vllm_logger.py
@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Logging configuration for vLLM."""
import datetime
import json
import logging
import sys
from collections.abc import Hashable
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
from os import path
from types import MethodType
from typing import Any, cast

import tpu_inference.mock.vllm_envs as envs

VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX

_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
           "[%(filename)s:%(lineno)d] %(message)s")
_DATE_FORMAT = "%m-%d %H:%M:%S"

DEFAULT_LOGGING_CONFIG = {
    "formatters": {
        "vllm": {
            "class": "tpu_inference.vllm_logging_utils.NewLineFormatter",
            "datefmt": _DATE_FORMAT,
            "format": _FORMAT,
        },
    },
    "handlers": {
        "vllm": {
            "class": "logging.StreamHandler",
            "formatter": "vllm",
            "level": VLLM_LOGGING_LEVEL,
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "vllm": {
            "handlers": ["vllm"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
    "version": 1,
    "disable_existing_loggers": False
}


@lru_cache
def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
    # Set the stacklevel to 2 to print the original caller's line info
    logger.debug(msg, *args, stacklevel=2)


@lru_cache
def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
    # Set the stacklevel to 2 to print the original caller's line info
    logger.info(msg, *args, stacklevel=2)


@lru_cache
def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
    # Set the stacklevel to 2 to print the original caller's line info
    logger.warning(msg, *args, stacklevel=2)


class _VllmLogger(Logger):
    """
    Note:
        This class is just to provide type information.
        We actually patch the methods directly on the [`logging.Logger`][]
        instance to avoid conflicting with other libraries such as
        `intel_extension_for_pytorch.utils._logger`.
    """

    def debug_once(self, msg: str, *args: Hashable) -> None:
        """
        As [`debug`][logging.Logger.debug], but subsequent calls with
        the same message are silently dropped.
        """
        _print_debug_once(self, msg, *args)

    def info_once(self, msg: str, *args: Hashable) -> None:
        """
        As [`info`][logging.Logger.info], but subsequent calls with
        the same message are silently dropped.
        """
        _print_info_once(self, msg, *args)

    def warning_once(self, msg: str, *args: Hashable) -> None:
        """
        As [`warning`][logging.Logger.warning], but subsequent calls with
        the same message are silently dropped.
        """
        _print_warning_once(self, msg, *args)


# Pre-defined methods mapping to avoid repeated dictionary creation
_METHODS_TO_PATCH = {
    "debug_once": _print_debug_once,
    "info_once": _print_info_once,
    "warning_once": _print_warning_once,
}


def _configure_vllm_root_logger() -> None:
    logging_config = dict[str, Any]()

    if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
        raise RuntimeError(
            "VLLM_CONFIGURE_LOGGING evaluated to false, but "
            "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
            "implies VLLM_CONFIGURE_LOGGING. Please enable "
            "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")

    if VLLM_CONFIGURE_LOGGING:
        logging_config = DEFAULT_LOGGING_CONFIG

    if VLLM_LOGGING_CONFIG_PATH:
        if not path.exists(VLLM_LOGGING_CONFIG_PATH):
            raise RuntimeError(
                "Could not load logging config. File does not exist: %s",
                VLLM_LOGGING_CONFIG_PATH)
        with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
            custom_config = json.loads(file.read())

        if not isinstance(custom_config, dict):
            raise ValueError("Invalid logging config. Expected dict, got %s.",
                             type(custom_config).__name__)
        logging_config = custom_config

    for formatter in logging_config.get("formatters", {}).values():
        # This provides backwards compatibility after #10134.
        if formatter.get(
                "class") == "tpu_inference.vllm_logging_utils.NewLineFormatter":
            formatter[
                "class"] = "tpu_inference.mock.vllm_logging_utils.NewLineFormatter"

    if logging_config:
        dictConfig(logging_config)


# The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_configure_vllm_root_logger()


def init_logger(name: str) -> _VllmLogger:
    """The main purpose of this function is to ensure that loggers are
    retrieved in such a way that we can be sure the root vllm logger has
    already been configured."""

    logger = logging.getLogger(name)

    for method_name, method in _METHODS_TO_PATCH.items():
        setattr(logger, method_name, MethodType(method, logger))

    return cast(_VllmLogger, logger)


logger = init_logger(__name__)


def _trace_calls(log_path, root_dir, frame, event, arg=None):
    if event in ['call', 'return']:
        # Extract the filename, line number, function name, and the code object
        filename = frame.f_code.co_filename
        lineno = frame.f_lineno
        func_name = frame.f_code.co_name
        if not filename.startswith(root_dir):
            # only log the functions in the vllm root_dir
            return
        # Log every function call or return
        try:
            last_frame = frame.f_back
            if last_frame is not None:
                last_filename = last_frame.f_code.co_filename
                last_lineno = last_frame.f_lineno
                last_func_name = last_frame.f_code.co_name
            else:
                # initial frame
                last_filename = ""
                last_lineno = 0
                last_func_name = ""
            with open(log_path, 'a') as f:
                ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                if event == 'call':
                    f.write(f"{ts} Call to"
                            f" {func_name} in {filename}:{lineno}"
                            f" from {last_func_name} in {last_filename}:"
                            f"{last_lineno}\n")
                else:
                    f.write(f"{ts} Return from"
                            f" {func_name} in {filename}:{lineno}"
                            f" to {last_func_name} in {last_filename}:"
                            f"{last_lineno}\n")
        except NameError:
            # modules are deleted during shutdown
            pass
    return partial(_trace_calls, log_path, root_dir)


def enable_trace_function_call(log_file_path: str, root_dir: str):
    sys.settrace(partial(_trace_calls, log_file_path, root_dir))
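For orientation, a minimal usage sketch of the module above, assuming it is importable as tpu_inference.mock.vllm_logger in your environment; the logger name and messages are illustrative only:

from tpu_inference.mock.vllm_logger import init_logger

log = init_logger("vllm.demo")
log.info("model loaded in %.1f s", 3.2)

# The *_once helpers are backed by lru_cache, so repeating the same
# message and arguments emits only a single record.
for _ in range(3):
    log.warning_once("falling back to eager attention")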
tpu_inference/mock/vllm_logging_utils.py
@@ -0,0 +1,15 @@
import logging


class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None, style="%"):
        logging.Formatter.__init__(self, fmt, datefmt, style)

    def format(self, record):
        msg = logging.Formatter.format(self, record)
        if record.message != "":
            parts = msg.split(record.message)
            msg = msg.replace("\n", "\r\n" + parts[0])
        return msg
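A small sketch of the formatter's effect on multi-line records, assuming the module path tpu_inference.mock.vllm_logging_utils; the handler setup and logger name are illustrative only:

import logging
import sys

from tpu_inference.mock.vllm_logging_utils import NewLineFormatter

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(NewLineFormatter("%(levelname)s %(message)s"))
demo = logging.getLogger("newline-demo")
demo.addHandler(handler)
demo.setLevel(logging.INFO)

# Every continuation line is re-prefixed, so the output stays aligned:
#   INFO first line
#   INFO second line
demo.info("first line\nsecond line")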
File without changes
File without changes
tpu_inference/models/common/model_loader.py
@@ -0,0 +1,444 @@
import functools
from typing import Any, Optional

import jax
import torch
from flax import nnx
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torchax.ops.mappings import j2t_dtype
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.utils.func_utils import supports_kw

from tpu_inference import envs
from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    apply_qwix_on_abstract_model, apply_qwix_quantization,
    load_random_weights_into_qwix_abstract_model)

logger = init_logger(__name__)

_MODEL_REGISTRY = {}


class UnsupportedArchitectureError(ValueError):
    """Raised when a model architecture is not supported in the registry."""
    pass


def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
    # NOTE: Use inline imports here, otherwise the normal imports
    # would cause JAX init failure when using multi hosts with Ray.
    from tpu_inference.models.jax.deepseek_v3 import DeepSeekV3
    from tpu_inference.models.jax.gpt_oss import GptOss
    from tpu_inference.models.jax.llama3 import LlamaForCausalLM
    from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
    from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
    from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
    from tpu_inference.models.jax.qwen2_5_vl import \
        Qwen2_5_VLForConditionalGeneration
    from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
    _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
    _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
    _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
    _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
    _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
    _MODEL_REGISTRY[
        "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
    _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
    _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss

    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise UnsupportedArchitectureError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def _get_nnx_model(
    model_class: Any,
    vllm_config: VllmConfig,
    rng: jax.Array,
    mesh: Mesh,
) -> nnx.Module:

    def create_abstract_model() -> nnx.Module:
        """
        Helper to create an abstract model for `nnx.eval_shape`.

        Returns:
            An abstract model function.
        """
        return model_class(vllm_config, rng, mesh)

    @nnx.jit(donate_argnums=(0, ),
             static_argnames=('use_qwix_on_abstract_model', ))
    def create_jit_model(
            model: nnx.Module,
            use_qwix_on_abstract_model: bool = False) -> nnx.Module:
        """
        Create a jit model.

        Args:
            model: The model to jit.
            use_qwix_on_abstract_model: Whether to apply Qwix on the abstract model.

        Returns:
            The jitted model.
        """
        state = nnx.state(model)
        nnx.update(model, state)
        if not use_qwix_on_abstract_model:
            # NOTE: if Qwix is not configured, this will be a no-op
            model = apply_qwix_quantization(vllm_config,
                                            model,
                                            rng,
                                            mesh,
                                            apply_to_abstract_model=False)
        return model

    if vllm_config.load_config.load_format == "dummy":
        # Create a sharded model with randomly initialized weights.
        # TODO: currently Qwen2ForCausalLM is using the legacy model implementation;
        # the random-init logic will be merged once all models are migrated to the
        # new model implementation.

        # Handle the case where we want to load random weights into a Qwix-quantized
        # model. Here, we need to run an abstract pass for Qwix first and then load
        # in the random weights.
        if apply_qwix_on_abstract_model(vllm_config):
            abstract_model_fn = apply_qwix_quantization(
                vllm_config,
                create_abstract_model,
                rng,
                mesh,
                apply_to_abstract_model=True)

            model = nnx.eval_shape(abstract_model_fn)
            quantization_config = vllm_config.model_config.hf_config.quantization_config if hasattr(
                vllm_config.model_config.hf_config,
                "quantization_config") else {}
            load_random_weights_into_qwix_abstract_model(
                rng, model, mesh, quantization_config)
            with mesh:
                jit_model = create_jit_model(model,
                                             use_qwix_on_abstract_model=True)
            return jit_model

        @nnx.jit
        def create_sharded_model():
            model = model_class(vllm_config, rng, mesh)
            state = nnx.state(model)
            pspecs = nnx.get_partition_spec(state)
            sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
            nnx.update(model, sharded_state)
            # NOTE: we don't support quantization for the old Qwen2ForCausalLM implementation
            return model

        with mesh:
            jit_model = create_sharded_model()
            # In this case, we are applying Qwix quantization to the true, concrete model
            jit_model = apply_qwix_quantization(vllm_config,
                                                jit_model,
                                                rng,
                                                mesh,
                                                apply_to_abstract_model=False)
        if hasattr(jit_model, 'initialize_cache'):
            jit_model.initialize_cache()
    else:
        # We first create an abstract model without allocating any weights,
        # then fill in its weights during load_weights from HF.
        # This has two advantages over the normal way:
        # 1. The model weights are only allocated once. Otherwise the normal way
        #    would random-init the model weights first and then load the real
        #    weights; the two-pass weight allocation makes model loading slow.
        # 2. Model loading won't OOM. Otherwise the normal way would hold the
        #    full model weights after random-init and then duplicate a layer
        #    during load_weights, which easily OOMs if the layer is very large.
        abstract_model_fn = create_abstract_model
        # NOTE: only one of the abstract (this) or concrete Qwix quantization
        # paths should be taken
        if should_apply_qwix_on_abstract_model := apply_qwix_on_abstract_model(
                vllm_config):
            # NOTE: if Qwix is not configured, this will return
            # `create_abstract_model` and thus be a no-op
            abstract_model_fn = apply_qwix_quantization(
                vllm_config,
                create_abstract_model,
                rng,
                mesh,
                apply_to_abstract_model=True)
        model = nnx.eval_shape(abstract_model_fn)
        # Although the created model can already work, we still need to jit
        # the model creation again, otherwise the model forward will have
        # non-trivial overhead in PjitFunction.
        with mesh:
            model.load_weights(rng)
            jit_model = create_jit_model(
                model,
                use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
    return jit_model


# TODO(pooyam): We need to refactor this. It returns a bunch of functions that
# do not work with all models, and that is not easy to see from the code.
def get_flax_model(
    vllm_config: VllmConfig,
    rng: jax.Array,
    mesh: Mesh,
    is_draft_model: bool = False,
) -> nnx.Module:
    if is_draft_model:
        model_class = _get_model_architecture(
            vllm_config.speculative_config.draft_model_config.hf_config)
    else:
        model_class = _get_model_architecture(
            vllm_config.model_config.hf_config)
    jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
    kv_cache_sharding = NamedSharding(
        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model"))
    hidden_states_sharding = NamedSharding(
        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))  # (T, D)

    # For performance considerations, refer to:
    # https://flax.readthedocs.io/en/latest/guides/performance.html
    graphdef, state = nnx.split(jit_model)

    @functools.partial(
        jax.jit,
        out_shardings=(
            kv_cache_sharding,
            hidden_states_sharding,
            hidden_states_sharding,  # aux hidden states
        ),
        donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
        static_argnums=6,  # 6 is layer_name_to_kvcache_index
    )
    def run_model(graphdef, state, *args):
        model = nnx.merge(graphdef, state)
        return model(*args)

    logits_sharding = NamedSharding(
        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))

    @functools.partial(
        jax.jit,
        out_shardings=(logits_sharding),
    )
    def run_compute_logits(graphdef, state, *args):
        model = nnx.merge(graphdef, state)
        hidden_state, *_ = args
        return model.compute_logits(hidden_state)

    # Multi-modal support only.
    # This function calculates the image tokens' embeddings with the ViT.
    def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
                                      **kwargs):
        model = nnx.merge(graphdef, state)
        return model.get_multimodal_embeddings(image_grid_thw, **kwargs)

    # This function calculates the embeddings of the input texts and then
    # merges them with the image embeddings.
    @functools.partial(
        jax.jit,
        out_shardings=(logits_sharding),
    )
    def run_get_input_embeddings(graphdef, state, *args, **kwargs):
        model = nnx.merge(graphdef, state)
        return model.get_input_embeddings(*args, **kwargs)

    # For models that want to work with EAGLE-3 speculative decoding
    @functools.partial(
        jax.jit,
        out_shardings=(logits_sharding),
    )
    def combine_hidden_states(graphdef, state, hidden_states):
        model = nnx.merge(graphdef, state)
        return model.combine_hidden_states(hidden_states)

    model = nnx.merge(graphdef, state)
    precompile_vision_encoder_fn = getattr(model, "precompile_vision_encoder",
                                           None)
    model_fn = functools.partial(run_model, graphdef)
    compute_logits_fn = functools.partial(run_compute_logits, graphdef)
    get_multimodal_embeddings_fn = functools.partial(
        run_get_multimodal_embeddings, graphdef)
    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
                                                graphdef)
    lora_manager, model = None, None
    combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                 graphdef)

    get_mrope_input_positions_fn = None if not hasattr(
        jit_model,
        "get_mrope_input_positions") else jit_model.get_mrope_input_positions

    multimodal_fns = {
        "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
        "get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
        "get_input_embeddings_fn": get_input_embeddings_fn,
        "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
    }

    return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model


def get_vllm_model(
    vllm_config: VllmConfig,
    rng: jax.Array,
    mesh: Mesh,
):
    from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper

    model = VllmModelWrapper(
        vllm_config=vllm_config,
        rng=rng,
        mesh=mesh,
    )
    params, lora_manager = model.load_weights()

    jit_model = model.jit_step_func()
    compute_logits_fn = model.jit_compute_logits_func()
    # The model needs to be returned because LoRA weights are neither
    # torch.nn.Parameter objects nor torch.nn buffers. After the LoRA weights
    # are loaded and set on the torch.nn.Module, they can be sharded and moved
    # to the TPU.
    combine_hidden_states_fn = None
    return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model


def get_model(
    vllm_config: VllmConfig,
    rng: jax.Array,
    mesh: Mesh,
    is_draft_model: bool = False,
) -> Any:
    impl = envs.MODEL_IMPL_TYPE
    logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")

    if impl == "flax_nnx":
        try:
            # Try to load the flax model first
            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
        except UnsupportedArchitectureError as e:
            # Convert the error message to a string to check its contents
            error_msg = str(e)

            logger.warning(f"Flax model failed with: '{error_msg}'. "
                           "Falling back to vLLM implementation.")
            # Fall back to the vLLM model and update the dtype accordingly
            vllm_config.model_config.dtype = j2t_dtype(
                vllm_config.model_config.dtype.dtype)
            return get_vllm_model(vllm_config, rng, mesh)
    elif impl == "vllm":
        return get_vllm_model(vllm_config, rng, mesh)
    else:
        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")


def _validate_model_interface(model: Any) -> None:
    """Validates that the model class has the required methods and signatures.

    A valid model must have:
    - An __init__ method that accepts a 'vllm_config' keyword argument.
    - A __call__ method that accepts 'kv_caches', 'input_ids', and
      'attention_metadata' keyword arguments.

    Args:
        model: The model class to validate.

    Raises:
        TypeError: If the model does not meet the interface requirements.
    """
    # Check for __init__ with vllm_config
    model_init = getattr(model, "__init__", None)
    if not callable(model_init):
        raise TypeError(
            f"Model {model.__name__} must have an __init__ method.")

    if not supports_kw(model_init, "vllm_config"):
        raise TypeError(
            f"Model {model.__name__} __init__ method must accept a "
            "'vllm_config' keyword argument.")

    # Check for __call__ with required arguments
    model_call = getattr(model, "__call__", None)
    # A class object is always callable (it produces an instance).
    # We need to check whether the class _explicitly_ defines a __call__
    # method for its instances, which is different from `type.__call__`.
    has_defined_call = False
    if isinstance(model, type):
        if any("__call__" in C.__dict__ for C in model.__mro__):
            has_defined_call = True
    elif callable(model_call):
        # For an instance, a simple callable check is sufficient.
        has_defined_call = True

    if not has_defined_call:
        raise TypeError(f"Model {model.__name__} must have a __call__ method.")

    required_call_args = ("kv_caches", "input_ids", "attention_metadata")
    missing_args = tuple(arg for arg in required_call_args
                         if not supports_kw(model_call, arg))

    if missing_args:
        raise TypeError(
            f"Model {model.__name__} __call__ method is missing required "
            f"keyword arguments: {missing_args}")


def register_model(arch: str, model: Any) -> None:
    """
    Registers a model class for a given architecture name.

    This function registers the model with both the tpu_inference registry
    and the vLLM registry. For vLLM, it creates a compatible wrapper
    around the JAX model.

    Args:
        arch: The name of the architecture (e.g., "LlamaForCausalLM").
        model: The JAX model class to register (e.g., a flax.nnx.Module).
    """
    _validate_model_interface(model)

    # Register with the tpu_inference registry for the JAX backend
    _MODEL_REGISTRY[arch] = model

    # Create a vLLM-compatible wrapper for the JAX model class.
    # This wrapper inherits from the JAX model and torch.nn.Module
    # to pass vLLM's type checks. It is not meant to be instantiated
    # or executed by vLLM's PyTorch backend.
    def unimplemented_forward(
        self,
        input_ids: "torch.Tensor",
        positions: "torch.Tensor",
        intermediate_tensors: Optional[Any] = None,
        inputs_embeds: Optional["torch.Tensor"] = None,
    ) -> None:
        raise NotImplementedError(
            "This is a JAX model and does not implement the PyTorch forward method."
        )

    # We need a custom __init__ that only calls torch.nn.Module's init,
    # to avoid triggering JAX logic when vLLM inspects the class.
    def wrapper_init(self, *args, **kwargs):
        torch.nn.Module.__init__(self)

    # Dynamically create the wrapper class that is a subclass of both the
    # JAX model and torch.nn.Module.
    VllmCompatibleModel = type(
        f"VllmCompatible{model.__name__}",
        (model, torch.nn.Module),
        {
            "__init__": wrapper_init,
            "forward": unimplemented_forward,
            # Prevent vLLM from trying to load weights into this dummy class.
            "load_weights": lambda self, *args, **kwargs: None,
        })

    # Register the wrapped model with vLLM's registry.
    from vllm.model_executor.models.registry import ModelRegistry
    ModelRegistry.register_model(arch, VllmCompatibleModel)
    logger.info(
        f"Registered JAX model {arch} with tpu_inference and vLLM registries.")
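As an illustrative sketch of the register_model hook above: MyTinyModel and its internals are hypothetical and not part of the package; only the constructor and __call__ keyword arguments required by _validate_model_interface are shown.

from flax import nnx

from tpu_inference.models.common.model_loader import register_model


class MyTinyModel(nnx.Module):

    # Must accept a 'vllm_config' keyword argument.
    def __init__(self, vllm_config, rng, mesh):
        self.vllm_config = vllm_config

    # Must accept 'kv_caches', 'input_ids' and 'attention_metadata'.
    def __call__(self, kv_caches, input_ids, attention_metadata):
        raise NotImplementedError


# Adds the class to _MODEL_REGISTRY and registers a torch.nn.Module-compatible
# wrapper with vLLM's ModelRegistry.
register_model("MyTinyModelForCausalLM", MyTinyModel)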
File without changes