tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0

tpu_inference/mock/vllm_logger.py (new file)
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Logging configuration for vLLM."""
+import datetime
+import json
+import logging
+import sys
+from collections.abc import Hashable
+from functools import lru_cache, partial
+from logging import Logger
+from logging.config import dictConfig
+from os import path
+from types import MethodType
+from typing import Any, cast
+
+import tpu_inference.mock.vllm_envs as envs
+
+VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
+VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
+VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
+VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
+
+_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
+           "[%(filename)s:%(lineno)d] %(message)s")
+_DATE_FORMAT = "%m-%d %H:%M:%S"
+
+DEFAULT_LOGGING_CONFIG = {
+    "formatters": {
+        "vllm": {
+            "class": "tpu_inference.vllm_logging_utils.NewLineFormatter",
+            "datefmt": _DATE_FORMAT,
+            "format": _FORMAT,
+        },
+    },
+    "handlers": {
+        "vllm": {
+            "class": "logging.StreamHandler",
+            "formatter": "vllm",
+            "level": VLLM_LOGGING_LEVEL,
+            "stream": "ext://sys.stdout",
+        },
+    },
+    "loggers": {
+        "vllm": {
+            "handlers": ["vllm"],
+            "level": "DEBUG",
+            "propagate": False,
+        },
+    },
+    "version": 1,
+    "disable_existing_loggers": False
+}
+
+
+@lru_cache
+def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.debug(msg, *args, stacklevel=2)
+
+
+@lru_cache
+def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.info(msg, *args, stacklevel=2)
+
+
+@lru_cache
+def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.warning(msg, *args, stacklevel=2)
+
+
+class _VllmLogger(Logger):
+    """
+    Note:
+        This class is just to provide type information.
+        We actually patch the methods directly on the [`logging.Logger`][]
+        instance to avoid conflicting with other libraries such as
+        `intel_extension_for_pytorch.utils._logger`.
+    """
+
+    def debug_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`debug`][logging.Logger.debug], but subsequent calls with
+        the same message are silently dropped.
+        """
+        _print_debug_once(self, msg, *args)
+
+    def info_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`info`][logging.Logger.info], but subsequent calls with
+        the same message are silently dropped.
+        """
+        _print_info_once(self, msg, *args)
+
+    def warning_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`warning`][logging.Logger.warning], but subsequent calls with
+        the same message are silently dropped.
+        """
+        _print_warning_once(self, msg, *args)
+
+
+# Pre-defined methods mapping to avoid repeated dictionary creation
+_METHODS_TO_PATCH = {
+    "debug_once": _print_debug_once,
+    "info_once": _print_info_once,
+    "warning_once": _print_warning_once,
+}
+
+
+def _configure_vllm_root_logger() -> None:
+    logging_config = dict[str, Any]()
+
+    if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
+        raise RuntimeError(
+            "VLLM_CONFIGURE_LOGGING evaluated to false, but "
+            "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
+            "implies VLLM_CONFIGURE_LOGGING. Please enable "
+            "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")
+
+    if VLLM_CONFIGURE_LOGGING:
+        logging_config = DEFAULT_LOGGING_CONFIG
+
+    if VLLM_LOGGING_CONFIG_PATH:
+        if not path.exists(VLLM_LOGGING_CONFIG_PATH):
+            raise RuntimeError(
+                "Could not load logging config. File does not exist: %s",
+                VLLM_LOGGING_CONFIG_PATH)
+        with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
+            custom_config = json.loads(file.read())
+
+        if not isinstance(custom_config, dict):
+            raise ValueError("Invalid logging config. Expected dict, got %s.",
+                             type(custom_config).__name__)
+        logging_config = custom_config
+
+    for formatter in logging_config.get("formatters", {}).values():
+        # This provides backwards compatibility after #10134.
+        if formatter.get(
+                "class"
+        ) == "tpu_inference.vllm_logging_utils.NewLineFormatter":
+            formatter[
+                "class"] = "tpu_inference.mock.vllm_logging_utils.NewLineFormatter"
+
+    if logging_config:
+        dictConfig(logging_config)
+
+
+# The root logger is initialized when the module is imported.
+# This is thread-safe as the module is only imported once,
+# guaranteed by the Python GIL.
+_configure_vllm_root_logger()
+
+
+def init_logger(name: str) -> _VllmLogger:
+    """The main purpose of this function is to ensure that loggers are
+    retrieved in such a way that we can be sure the root vllm logger has
+    already been configured."""
+
+    logger = logging.getLogger(name)
+
+    for method_name, method in _METHODS_TO_PATCH.items():
+        setattr(logger, method_name, MethodType(method, logger))
+
+    return cast(_VllmLogger, logger)
+
+
+logger = init_logger(__name__)
+
+
+def _trace_calls(log_path, root_dir, frame, event, arg=None):
+    if event in ['call', 'return']:
+        # Extract the filename, line number, function name, and the code object
+        filename = frame.f_code.co_filename
+        lineno = frame.f_lineno
+        func_name = frame.f_code.co_name
+        if not filename.startswith(root_dir):
+            # only log the functions in the vllm root_dir
+            return
+        # Log every function call or return
+        try:
+            last_frame = frame.f_back
+            if last_frame is not None:
+                last_filename = last_frame.f_code.co_filename
+                last_lineno = last_frame.f_lineno
+                last_func_name = last_frame.f_code.co_name
+            else:
+                # initial frame
+                last_filename = ""
+                last_lineno = 0
+                last_func_name = ""
+            with open(log_path, 'a') as f:
+                ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+                if event == 'call':
+                    f.write(f"{ts} Call to"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" from {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+                else:
+                    f.write(f"{ts} Return from"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" to {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+        except NameError:
+            # modules are deleted during shutdown
+            pass
+    return partial(_trace_calls, log_path, root_dir)
+
+
+def enable_trace_function_call(log_file_path: str, root_dir: str):
+    sys.settrace(partial(_trace_calls, log_file_path, root_dir))
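
For orientation, a minimal usage sketch of the new mock logger module above (not part of the diff; the import path follows the file list, everything else is illustrative):

    from tpu_inference.mock.vllm_logger import init_logger

    # init_logger returns a standard logging.Logger with debug_once /
    # info_once / warning_once patched onto the instance via _METHODS_TO_PATCH.
    logger = init_logger(__name__)

    logger.info("logged every time")

    # The *_once helpers are wrapped in functools.lru_cache, so a repeated
    # (logger, msg, args) combination is silently dropped after the first call.
    logger.warning_once("logged only once")
    logger.warning_once("logged only once")  # dropped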

tpu_inference/mock/vllm_logging_utils.py (new file)
@@ -0,0 +1,15 @@
+import logging
+
+
+class NewLineFormatter(logging.Formatter):
+    """Adds logging prefix to newlines to align multi-line messages."""
+
+    def __init__(self, fmt, datefmt=None, style="%"):
+        logging.Formatter.__init__(self, fmt, datefmt, style)
+
+    def format(self, record):
+        msg = logging.Formatter.format(self, record)
+        if record.message != "":
+            parts = msg.split(record.message)
+            msg = msg.replace("\n", "\r\n" + parts[0])
+        return msg
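
A quick sketch of what NewLineFormatter does with multi-line messages (illustrative only; the handler setup and format string here are assumptions, not part of the diff):

    import logging
    from tpu_inference.mock.vllm_logging_utils import NewLineFormatter

    handler = logging.StreamHandler()
    handler.setFormatter(NewLineFormatter("%(levelname)s %(message)s"))
    log = logging.getLogger("demo")
    log.addHandler(handler)
    log.setLevel(logging.INFO)

    # Every newline inside the message is re-prefixed with whatever precedes
    # the message in the formatted record ("INFO " here), so wrapped lines
    # stay aligned with the first one.
    log.info("first line\nsecond line")
    # INFO first line
    # INFO second line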

tpu_inference/models/common/model_loader.py
@@ -8,9 +8,6 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.runai_streamer_loader import \
-    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs

@@ -39,17 +36,19 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
-    from tpu_inference.models.jax.
+    from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
+    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
     from tpu_inference.models.jax.qwen2_5_vl import \
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
-    _MODEL_REGISTRY["
+    _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
     _MODEL_REGISTRY[
         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
+    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
     _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss
 

@@ -58,10 +57,8 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     if arch in _MODEL_REGISTRY:
         return _MODEL_REGISTRY[arch]
     raise UnsupportedArchitectureError(
-        f"Model architectures {architectures} not "
-        "
-        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
-    )
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
 
 
 def _get_nnx_model(

@@ -180,23 +177,7 @@ def _get_nnx_model(
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-
-        if isinstance(loader, RunaiModelStreamerLoader):
-            model_weights = vllm_config.model_config.model
-            if hasattr(vllm_config.model_config, "model_weights"):
-                model_weights = vllm_config.model_config.model_weights
-            weights_iterator = loader._get_weights_iterator(
-                model_weights, vllm_config.model_config.revision)
-            # We set the weights iterator at runtime, to prevent having to change
-            # every model's load_weights signature. This also prevents us from hitting
-            # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
-            # flax_nnx model whose load_weights function does not accept the
-            # weights_iterator keyword argument.
-            vllm_config.model_config.model_weights_iterator = weights_iterator
-            model.load_weights(rng)
-            del vllm_config.model_config.model_weights_iterator
-        else:
-            model.load_weights(rng)
+        model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)

@@ -236,9 +217,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
-        7, 10, 11
-        ),  #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
+        static_argnums=6,  #6 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
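
The hunk above narrows the jitted model call's static arguments from 7, 10, 11 to just 6 (layer_name_to_kvcache_index). As background, this is generic JAX behavior rather than tpu-inference code: arguments listed in static_argnums are treated as compile-time constants, so Python control flow may branch on them, but each new value triggers a retrace and recompile. A minimal sketch:

    import functools
    import jax
    import jax.numpy as jnp

    # `flag` is static: usable in ordinary Python control flow inside the
    # jitted function, but each distinct value compiles its own specialization.
    @functools.partial(jax.jit, static_argnums=1)
    def scale(x, flag):
        return x * 2 if flag else x

    scale(jnp.arange(3), True)   # traces and compiles for flag=True
    scale(jnp.arange(3), False)  # retraces and recompiles for flag=False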

@@ -347,8 +326,8 @@ def get_model(
     # Convert the error message to a string to check its contents
     error_msg = str(e)
 
-    logger.warning(error_msg
-
+    logger.warning(f"Flax model failed with: '{error_msg}'. "
+                   "Falling back to vLLM implementation.")
     # Fall back to the vLLM model and updating the dtype accordingly
     vllm_config.model_config.dtype = j2t_dtype(
         vllm_config.model_config.dtype.dtype)

@@ -442,17 +421,6 @@ def register_model(arch: str, model: Any) -> None:
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
-    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def unimplemented_get_input_embeddings(
-        self,
-        input_ids: "torch.Tensor",
-        positions: "torch.Tensor",
-        inputs_embeds: Optional["torch.Tensor"] = None,
-    ) -> "torch.Tensor":
-        raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
-        )
-
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):

@@ -466,7 +434,6 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
-            "get_input_embeddings": unimplemented_get_input_embeddings,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
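
For context, register_model is the hook these hunks simplify (per the file list it lives in tpu_inference/models/common/model_loader.py). Registering a custom JAX architecture looks roughly like the sketch below; the model class and architecture name are hypothetical and only the register_model signature is taken from the diff:

    from tpu_inference.models.common.model_loader import register_model

    class MyJaxForCausalLM:
        """Hypothetical flax.nnx-based model implementation."""
        ...

    # Maps the HF config's architecture string to the JAX implementation.
    register_model("MyJaxForCausalLM", MyJaxForCausalLM)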

tpu_inference/models/jax/llama3.py
@@ -368,8 +368,7 @@ class LlamaForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config.
-                                        self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,

tpu_inference/models/jax/llama_eagle3.py
@@ -194,12 +194,13 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.
+    model_config = vllm_config.model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
     num_kv_heads = hf_config.num_key_value_heads
-    hidden_size =
+    hidden_size = model_config.get_hidden_size()
+
     head_dim_original = model_config.get_head_size()
 
     metadata_map.reshape_map.update({

@@ -304,8 +305,6 @@ class EagleLlama3ForCausalLM(nnx.Module):
         "fc": "model.fc.kernel",
         "lm_head": "lm_head.kernel",
         "d2t": "draft_id_to_target_id",
-        "embed_tokens":
-        "model.embed_tokens.embedding",  # Some checkpoints need this
     }
 
     # Define keys to keep in original dtype (e.g., float32 for stability)

@@ -313,9 +312,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
         r".*d2t.*",
     ]
 
-    metadata_map = get_default_maps(
-        self.vllm_config.speculative_config.draft_model_config, self.mesh,
-        mappings)
+    metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
 
     update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 

@@ -327,7 +324,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
         is_draft_model=True,
         keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
 
-    # If the embedding is not initialized, initialize it with a
+    # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
     if isinstance(self.model.embed_tokens.embedding.value,
                   jax.ShapeDtypeStruct):
         self.model.embed_tokens.embedding.value = jnp.zeros(