tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/mock/vllm_logger.py
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Logging configuration for vLLM."""
+import datetime
+import json
+import logging
+import sys
+from collections.abc import Hashable
+from functools import lru_cache, partial
+from logging import Logger
+from logging.config import dictConfig
+from os import path
+from types import MethodType
+from typing import Any, cast
+
+import tpu_inference.mock.vllm_envs as envs
+
+VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
+VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
+VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
+VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
+
+_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
+           "[%(filename)s:%(lineno)d] %(message)s")
+_DATE_FORMAT = "%m-%d %H:%M:%S"
+
+DEFAULT_LOGGING_CONFIG = {
+    "formatters": {
+        "vllm": {
+            "class": "tpu_inference.vllm_logging_utils.NewLineFormatter",
+            "datefmt": _DATE_FORMAT,
+            "format": _FORMAT,
+        },
+    },
+    "handlers": {
+        "vllm": {
+            "class": "logging.StreamHandler",
+            "formatter": "vllm",
+            "level": VLLM_LOGGING_LEVEL,
+            "stream": "ext://sys.stdout",
+        },
+    },
+    "loggers": {
+        "vllm": {
+            "handlers": ["vllm"],
+            "level": "DEBUG",
+            "propagate": False,
+        },
+    },
+    "version": 1,
+    "disable_existing_loggers": False
+}
+
+
+@lru_cache
+def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.debug(msg, *args, stacklevel=2)
+
+
+@lru_cache
+def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.info(msg, *args, stacklevel=2)
+
+
+@lru_cache
+def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.warning(msg, *args, stacklevel=2)
+
+
+class _VllmLogger(Logger):
+    """
+    Note:
+        This class is just to provide type information.
+        We actually patch the methods directly on the [`logging.Logger`][]
+        instance to avoid conflicting with other libraries such as
+        `intel_extension_for_pytorch.utils._logger`.
+    """
+
+    def debug_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`debug`][logging.Logger.debug], but subsequent calls with
+        the same message are silently dropped.
+        """
+        _print_debug_once(self, msg, *args)
+
+    def info_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`info`][logging.Logger.info], but subsequent calls with
+        the same message are silently dropped.
+        """
+        _print_info_once(self, msg, *args)
+
+    def warning_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`warning`][logging.Logger.warning], but subsequent calls with
+        the same message are silently dropped.
+        """
+        _print_warning_once(self, msg, *args)
+
+
+# Pre-defined methods mapping to avoid repeated dictionary creation
+_METHODS_TO_PATCH = {
+    "debug_once": _print_debug_once,
+    "info_once": _print_info_once,
+    "warning_once": _print_warning_once,
+}
+
+
+def _configure_vllm_root_logger() -> None:
+    logging_config = dict[str, Any]()
+
+    if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
+        raise RuntimeError(
+            "VLLM_CONFIGURE_LOGGING evaluated to false, but "
+            "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
+            "implies VLLM_CONFIGURE_LOGGING. Please enable "
+            "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")
+
+    if VLLM_CONFIGURE_LOGGING:
+        logging_config = DEFAULT_LOGGING_CONFIG
+
+    if VLLM_LOGGING_CONFIG_PATH:
+        if not path.exists(VLLM_LOGGING_CONFIG_PATH):
+            raise RuntimeError(
+                "Could not load logging config. File does not exist: %s",
+                VLLM_LOGGING_CONFIG_PATH)
+        with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
+            custom_config = json.loads(file.read())
+
+        if not isinstance(custom_config, dict):
+            raise ValueError("Invalid logging config. Expected dict, got %s.",
+                             type(custom_config).__name__)
+        logging_config = custom_config
+
+    for formatter in logging_config.get("formatters", {}).values():
+        # This provides backwards compatibility after #10134.
+        if formatter.get(
+                "class"
+        ) == "tpu_inference.vllm_logging_utils.NewLineFormatter":
+            formatter[
+                "class"] = "tpu_inference.mock.vllm_logging_utils.NewLineFormatter"
+
+    if logging_config:
+        dictConfig(logging_config)
+
+
+# The root logger is initialized when the module is imported.
+# This is thread-safe as the module is only imported once,
+# guaranteed by the Python GIL.
+_configure_vllm_root_logger()
+
+
+def init_logger(name: str) -> _VllmLogger:
+    """The main purpose of this function is to ensure that loggers are
+    retrieved in such a way that we can be sure the root vllm logger has
+    already been configured."""
+
+    logger = logging.getLogger(name)
+
+    for method_name, method in _METHODS_TO_PATCH.items():
+        setattr(logger, method_name, MethodType(method, logger))
+
+    return cast(_VllmLogger, logger)
+
+
+logger = init_logger(__name__)
+
+
+def _trace_calls(log_path, root_dir, frame, event, arg=None):
+    if event in ['call', 'return']:
+        # Extract the filename, line number, function name, and the code object
+        filename = frame.f_code.co_filename
+        lineno = frame.f_lineno
+        func_name = frame.f_code.co_name
+        if not filename.startswith(root_dir):
+            # only log the functions in the vllm root_dir
+            return
+        # Log every function call or return
+        try:
+            last_frame = frame.f_back
+            if last_frame is not None:
+                last_filename = last_frame.f_code.co_filename
+                last_lineno = last_frame.f_lineno
+                last_func_name = last_frame.f_code.co_name
+            else:
+                # initial frame
+                last_filename = ""
+                last_lineno = 0
+                last_func_name = ""
+            with open(log_path, 'a') as f:
+                ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+                if event == 'call':
+                    f.write(f"{ts} Call to"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" from {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+                else:
+                    f.write(f"{ts} Return from"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" to {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+        except NameError:
+            # modules are deleted during shutdown
+            pass
+    return partial(_trace_calls, log_path, root_dir)
+
+
+def enable_trace_function_call(log_file_path: str, root_dir: str):
+    sys.settrace(partial(_trace_calls, log_file_path, root_dir))
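Note: this mock module mirrors vLLM's logger so tpu-inference can log without importing vLLM itself. A minimal usage sketch (assuming this wheel is installed; the logger name "my_module" is illustrative):

    from tpu_inference.mock.vllm_logger import init_logger

    logger = init_logger("my_module")
    logger.info("logged on every call")
    logger.info_once("logged the first time only")
    logger.info_once("logged the first time only")  # deduplicated via lru_cache, silently dropped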
tpu_inference/mock/vllm_logging_utils.py
@@ -0,0 +1,15 @@
+import logging
+
+
+class NewLineFormatter(logging.Formatter):
+    """Adds logging prefix to newlines to align multi-line messages."""
+
+    def __init__(self, fmt, datefmt=None, style="%"):
+        logging.Formatter.__init__(self, fmt, datefmt, style)
+
+    def format(self, record):
+        msg = logging.Formatter.format(self, record)
+        if record.message != "":
+            parts = msg.split(record.message)
+            msg = msg.replace("\n", "\r\n" + parts[0])
+        return msg
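Note: NewLineFormatter re-applies the computed log prefix after every newline, so multi-line messages stay aligned. A small self-contained sketch (the handler setup is illustrative, not from this package):

    import logging

    from tpu_inference.mock.vllm_logging_utils import NewLineFormatter

    handler = logging.StreamHandler()
    handler.setFormatter(NewLineFormatter("%(levelname)s %(message)s"))
    demo = logging.getLogger("demo")
    demo.addHandler(handler)
    demo.warning("first line\nsecond line")
    # Prints (prefix repeated after the newline):
    # WARNING first line
    # WARNING second line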
tpu_inference/models/common/model_loader.py
@@ -8,13 +8,10 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.runai_streamer_loader import \
-    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
-from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
@@ -39,17 +36,19 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
-    from tpu_inference.models.jax.
+    from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
+    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
     from tpu_inference.models.jax.qwen2_5_vl import \
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
-    _MODEL_REGISTRY["
+    _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
     _MODEL_REGISTRY[
         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
+    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
     _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss
 
@@ -58,10 +57,8 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     if arch in _MODEL_REGISTRY:
         return _MODEL_REGISTRY[arch]
     raise UnsupportedArchitectureError(
-        f"Model architectures {architectures} not "
-        "
-        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
-    )
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
 
 
 def _get_nnx_model(
@@ -180,23 +177,7 @@ def _get_nnx_model(
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-
-        if isinstance(loader, RunaiModelStreamerLoader):
-            model_weights = vllm_config.model_config.model
-            if hasattr(vllm_config.model_config, "model_weights"):
-                model_weights = vllm_config.model_config.model_weights
-            weights_iterator = loader._get_weights_iterator(
-                model_weights, vllm_config.model_config.revision)
-            # We set the weights iterator at runtime, to prevent having to change
-            # every model's load_weights signature. This also prevents us from hitting
-            # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
-            # flax_nnx model whose load_weights function does not accept the
-            # weights_iterator keyword argument.
-            vllm_config.model_config.model_weights_iterator = weights_iterator
-            model.load_weights(rng)
-            del vllm_config.model_config.model_weights_iterator
-        else:
-            model.load_weights(rng)
+        model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
@@ -236,9 +217,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=(
-            7, 10, 11
-        ),  #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
+        static_argnums=6,  #6 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
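Note: `static_argnums` in `jax.jit` marks positional arguments as compile-time constants, so the change above (three static arguments down to one) shrinks the set of values whose change forces a retrace. A generic illustration (toy function, not from this package):

    import functools

    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=1)
    def scale(x, factor):
        # `factor` is baked into the compiled program as a constant.
        return x * factor

    scale(jnp.ones(4), 2)  # traces and compiles for factor=2
    scale(jnp.ones(4), 3)  # new static value, so JAX retraces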
@@ -263,11 +242,10 @@
         model = nnx.merge(graphdef, state)
         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
 
-    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
     @functools.partial(
         jax.jit,
-        out_shardings=(embed_sharding),
+        out_shardings=(logits_sharding),
     )
     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
@@ -347,8 +325,8 @@ def get_model(
         # Convert the error message to a string to check its contents
         error_msg = str(e)
 
-        logger.warning(error_msg
-
+        logger.warning(f"Flax model failed with: '{error_msg}'. "
+                       "Falling back to vLLM implementation.")
         # Fall back to the vLLM model and updating the dtype accordingly
         vllm_config.model_config.dtype = j2t_dtype(
             vllm_config.model_config.dtype.dtype)
@@ -442,17 +420,6 @@ def register_model(arch: str, model: Any) -> None:
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
-    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def unimplemented_get_input_embeddings(
-        self,
-        input_ids: "torch.Tensor",
-        positions: "torch.Tensor",
-        inputs_embeds: Optional["torch.Tensor"] = None,
-    ) -> "torch.Tensor":
-        raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
-        )
-
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -466,7 +433,6 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
-            "get_input_embeddings": unimplemented_get_input_embeddings,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
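Note: `register_model` pairs a HuggingFace architecture string with a JAX model class, while the generated wrapper satisfies vLLM's PyTorch-oriented type checks (hence the dummy `forward` and no-op `load_weights`). A hedged sketch of registering an architecture (the alias name is illustrative):

    from tpu_inference.models.common.model_loader import register_model
    from tpu_inference.models.jax.phi3 import Phi3ForCausalLM

    # Register the JAX Phi-3 implementation under a custom architecture name.
    register_model("MyPhi3Variant", Phi3ForCausalLM)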
tpu_inference/models/jax/llama3.py
@@ -8,10 +8,10 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
-from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
+from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
@@ -368,8 +368,7 @@ class LlamaForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config.model_config,
-                                        self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/llama_eagle3.py
@@ -194,12 +194,13 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.speculative_config.draft_model_config
+    model_config = vllm_config.model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
     num_kv_heads = hf_config.num_key_value_heads
-    hidden_size = hf_config.hidden_size
+    hidden_size = model_config.get_hidden_size()
+
     head_dim_original = model_config.get_head_size()
 
     metadata_map.reshape_map.update({
@@ -304,8 +305,6 @@ class EagleLlama3ForCausalLM(nnx.Module):
             "fc": "model.fc.kernel",
             "lm_head": "lm_head.kernel",
             "d2t": "draft_id_to_target_id",
-            "embed_tokens":
-            "model.embed_tokens.embedding",  # Some checkpoints need this
         }
 
         # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -313,9 +312,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]
 
-        metadata_map = get_default_maps(
-            self.vllm_config.speculative_config.draft_model_config, self.mesh,
-            mappings)
+        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
 
         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 
@@ -327,7 +324,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
             is_draft_model=True,
             keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
 
-        # If the embedding is not initialized, initialize it with a
+        # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
                       jax.ShapeDtypeStruct):
             self.model.embed_tokens.embedding.value = jnp.zeros(