tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff shows the changes between publicly available package versions as published to a supported registry, and is provided for informational purposes only.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/mock/vllm_logger.py
DELETED

@@ -1,212 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Logging configuration for vLLM."""
import datetime
import json
import logging
import sys
from collections.abc import Hashable
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
from os import path
from types import MethodType
from typing import Any, cast

import tpu_inference.mock.vllm_envs as envs

VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX

_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
           "[%(filename)s:%(lineno)d] %(message)s")
_DATE_FORMAT = "%m-%d %H:%M:%S"

DEFAULT_LOGGING_CONFIG = {
    "formatters": {
        "vllm": {
            "class": "tpu_inference.vllm_logging_utils.NewLineFormatter",
            "datefmt": _DATE_FORMAT,
            "format": _FORMAT,
        },
    },
    "handlers": {
        "vllm": {
            "class": "logging.StreamHandler",
            "formatter": "vllm",
            "level": VLLM_LOGGING_LEVEL,
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "vllm": {
            "handlers": ["vllm"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
    "version": 1,
    "disable_existing_loggers": False
}


@lru_cache
def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
    # Set the stacklevel to 2 to print the original caller's line info
    logger.debug(msg, *args, stacklevel=2)


@lru_cache
def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
    # Set the stacklevel to 2 to print the original caller's line info
    logger.info(msg, *args, stacklevel=2)


@lru_cache
def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
    # Set the stacklevel to 2 to print the original caller's line info
    logger.warning(msg, *args, stacklevel=2)


class _VllmLogger(Logger):
    """
    Note:
        This class is just to provide type information.
        We actually patch the methods directly on the [`logging.Logger`][]
        instance to avoid conflicting with other libraries such as
        `intel_extension_for_pytorch.utils._logger`.
    """

    def debug_once(self, msg: str, *args: Hashable) -> None:
        """
        As [`debug`][logging.Logger.debug], but subsequent calls with
        the same message are silently dropped.
        """
        _print_debug_once(self, msg, *args)

    def info_once(self, msg: str, *args: Hashable) -> None:
        """
        As [`info`][logging.Logger.info], but subsequent calls with
        the same message are silently dropped.
        """
        _print_info_once(self, msg, *args)

    def warning_once(self, msg: str, *args: Hashable) -> None:
        """
        As [`warning`][logging.Logger.warning], but subsequent calls with
        the same message are silently dropped.
        """
        _print_warning_once(self, msg, *args)


# Pre-defined methods mapping to avoid repeated dictionary creation
_METHODS_TO_PATCH = {
    "debug_once": _print_debug_once,
    "info_once": _print_info_once,
    "warning_once": _print_warning_once,
}


def _configure_vllm_root_logger() -> None:
    logging_config = dict[str, Any]()

    if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
        raise RuntimeError(
            "VLLM_CONFIGURE_LOGGING evaluated to false, but "
            "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
            "implies VLLM_CONFIGURE_LOGGING. Please enable "
            "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")

    if VLLM_CONFIGURE_LOGGING:
        logging_config = DEFAULT_LOGGING_CONFIG

    if VLLM_LOGGING_CONFIG_PATH:
        if not path.exists(VLLM_LOGGING_CONFIG_PATH):
            raise RuntimeError(
                "Could not load logging config. File does not exist: %s",
                VLLM_LOGGING_CONFIG_PATH)
        with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
            custom_config = json.loads(file.read())

        if not isinstance(custom_config, dict):
            raise ValueError("Invalid logging config. Expected dict, got %s.",
                             type(custom_config).__name__)
        logging_config = custom_config

    for formatter in logging_config.get("formatters", {}).values():
        # This provides backwards compatibility after #10134.
        if formatter.get(
                "class"
        ) == "tpu_inference.vllm_logging_utils.NewLineFormatter":
            formatter[
                "class"] = "tpu_inference.mock.vllm_logging_utils.NewLineFormatter"

    if logging_config:
        dictConfig(logging_config)


# The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_configure_vllm_root_logger()


def init_logger(name: str) -> _VllmLogger:
    """The main purpose of this function is to ensure that loggers are
    retrieved in such a way that we can be sure the root vllm logger has
    already been configured."""

    logger = logging.getLogger(name)

    for method_name, method in _METHODS_TO_PATCH.items():
        setattr(logger, method_name, MethodType(method, logger))

    return cast(_VllmLogger, logger)


logger = init_logger(__name__)


def _trace_calls(log_path, root_dir, frame, event, arg=None):
    if event in ['call', 'return']:
        # Extract the filename, line number, function name, and the code object
        filename = frame.f_code.co_filename
        lineno = frame.f_lineno
        func_name = frame.f_code.co_name
        if not filename.startswith(root_dir):
            # only log the functions in the vllm root_dir
            return
        # Log every function call or return
        try:
            last_frame = frame.f_back
            if last_frame is not None:
                last_filename = last_frame.f_code.co_filename
                last_lineno = last_frame.f_lineno
                last_func_name = last_frame.f_code.co_name
            else:
                # initial frame
                last_filename = ""
                last_lineno = 0
                last_func_name = ""
            with open(log_path, 'a') as f:
                ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                if event == 'call':
                    f.write(f"{ts} Call to"
                            f" {func_name} in {filename}:{lineno}"
                            f" from {last_func_name} in {last_filename}:"
                            f"{last_lineno}\n")
                else:
                    f.write(f"{ts} Return from"
                            f" {func_name} in {filename}:{lineno}"
                            f" to {last_func_name} in {last_filename}:"
                            f"{last_lineno}\n")
        except NameError:
            # modules are deleted during shutdown
            pass
    return partial(_trace_calls, log_path, root_dir)


def enable_trace_function_call(log_file_path: str, root_dir: str):
    sys.settrace(partial(_trace_calls, log_file_path, root_dir))

tpu_inference/mock/vllm_logging_utils.py
DELETED

@@ -1,15 +0,0 @@
import logging


class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None, style="%"):
        logging.Formatter.__init__(self, fmt, datefmt, style)

    def format(self, record):
        msg = logging.Formatter.format(self, record)
        if record.message != "":
            parts = msg.split(record.message)
            msg = msg.replace("\n", "\r\n" + parts[0])
        return msg

tpu_inference/models/jax/phi3.py
DELETED
@@ -1,376 +0,0 @@
from typing import List, Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import Mesh
from transformers import Phi3Config, modeling_flax_utils
from vllm.config import VllmConfig

from tpu_inference import utils
from tpu_inference.layers.common.attention_interface import attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
                                                          load_hf_weights)

logger = init_logger(__name__)

init_fn = nnx.initializers.uniform()


class Phi3MLP(nnx.Module):

    def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs):
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        act = config.hidden_act

        self.gate_up_proj = nnx.Linear(
            hidden_size,
            2 * intermediate_size,
            use_bias=False,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            rngs=rng,
        )
        self.down_proj = nnx.Linear(
            intermediate_size,
            hidden_size,
            use_bias=False,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
            rngs=rng,
        )
        self.act_fn = modeling_flax_utils.ACT2FN[act]

    def __call__(self, x: jax.Array) -> jax.Array:
        gate_up = self.gate_up_proj(x)
        gate, up = jnp.split(gate_up, 2, axis=-1)
        fuse = up * self.act_fn(gate)
        result = self.down_proj(fuse)
        return result


class Phi3Attention(nnx.Module):

    def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
                 mesh: Mesh, kv_cache_dtype: str):
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.rope_theta = config.rope_theta
        self.rope_scaling = config.rope_scaling
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.max_position_embeddings = config.max_position_embeddings

        self.head_dim_original = getattr(config, "head_dim",
                                         self.hidden_size // self.num_heads)
        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)

        sharding_size = mesh.shape["model"]
        self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                    sharding_size)
        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
                                                       sharding_size)

        self.mesh = mesh

        self.qkv_proj = nnx.Einsum(
            "TD,DNH->TNH",
            (self.hidden_size, self.num_heads + self.num_kv_heads * 2,
             self.head_dim),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
            rngs=rng,
        )
        self.o_proj = nnx.Einsum(
            "TNH,NHD->TD",
            (self.num_heads, self.head_dim, self.hidden_size),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
            rngs=rng,
        )

        self._q_scale = 1.0
        self._k_scale = 1.0
        self._v_scale = 1.0
        self.kv_cache_quantized_dtype = None
        if kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                kv_cache_dtype)

    def __call__(
        self,
        kv_cache: Optional[jax.Array],
        x: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[jax.Array, jax.Array]:
        md = attention_metadata
        # qkv: (T, N + K * 2, H)
        qkv = self.qkv_proj(x)
        q, k, v = jnp.split(
            qkv, [self.num_heads, self.num_heads + self.num_kv_heads], axis=1)
        if self.rope_scaling:
            q = apply_longrope(q, md.input_positions, self.head_dim_original,
                               self.rope_scaling,
                               self.original_max_position_embeddings,
                               self.max_position_embeddings, self.rope_theta)
            k = apply_longrope(k, md.input_positions, self.head_dim_original,
                               self.rope_scaling,
                               self.original_max_position_embeddings,
                               self.max_position_embeddings, self.rope_theta)
        else:
            q = apply_rope(q, md.input_positions, self.head_dim_original,
                           self.rope_theta, self.rope_scaling)
            k = apply_rope(k, md.input_positions, self.head_dim_original,
                           self.rope_theta, self.rope_scaling)
        # o: (T, N, H)
        q_scale = k_scale = v_scale = None
        if self.kv_cache_quantized_dtype:
            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
            # q_scale = self._q_scale
            k_scale = self._k_scale
            v_scale = self._v_scale
            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
                                     k_scale, v_scale)
        new_kv_cache, outputs = attention(
            kv_cache,
            q,
            k,
            v,
            attention_metadata,
            self.mesh,
            self.head_dim_original,
            q_scale=q_scale,
            k_scale=k_scale,
            v_scale=v_scale,
        )
        # (T, D)
        o = self.o_proj(outputs)
        return new_kv_cache, o


class Phi3DecoderLayer(nnx.Module):

    def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
                 mesh: Mesh, kv_cache_dtype: str):
        rms_norm_eps = config.rms_norm_eps
        hidden_size = config.hidden_size

        self.input_layernorm = nnx.RMSNorm(
            hidden_size,
            epsilon=rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )
        self.self_attn = Phi3Attention(config=config,
                                       dtype=dtype,
                                       rng=rng,
                                       mesh=mesh,
                                       kv_cache_dtype=kv_cache_dtype)
        self.post_attention_layernorm = nnx.RMSNorm(
            hidden_size,
            epsilon=rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )
        self.mlp = Phi3MLP(
            config=config,
            dtype=dtype,
            rng=rng,
        )

    def __call__(
        self,
        kv_cache: jax.Array,
        x: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[jax.Array, jax.Array]:
        hidden_states = self.input_layernorm(x)
        kv_cache, attn_output = self.self_attn(
            kv_cache,
            hidden_states,
            attention_metadata,
        )
        attn_output += x

        residual = attn_output
        attn_output = self.post_attention_layernorm(attn_output)
        outputs = self.mlp(attn_output)
        outputs = residual + outputs
        return kv_cache, outputs


class Phi3Model(nnx.Module):

    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 mesh: Mesh) -> None:
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        vocab_size = model_config.get_vocab_size()
        dtype = model_config.dtype
        rms_norm_eps = hf_config.rms_norm_eps
        hidden_size = hf_config.hidden_size

        self.embed = nnx.Embed(
            num_embeddings=vocab_size,
            features=hidden_size,
            param_dtype=dtype,
            embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
            rngs=rng,
        )
        self.layers = [
            Phi3DecoderLayer(
                config=hf_config,
                dtype=dtype,
                rng=rng,
                mesh=mesh,
                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
            for _ in range(hf_config.num_hidden_layers)
        ]
        self.norm = nnx.RMSNorm(
            hidden_size,
            epsilon=rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )
        if model_config.hf_config.tie_word_embeddings:
            self.lm_head = self.embed.embedding
        else:
            self.lm_head = nnx.Param(
                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
                sharding=(None, "model"),
            )

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[List[jax.Array], jax.Array]:
        x = self.embed(input_ids)
        for i, layer in enumerate(self.layers):
            kv_cache = kv_caches[i]
            kv_cache, x = layer(
                kv_cache,
                x,
                attention_metadata,
            )
            kv_caches[i] = kv_cache
        x = self.norm(x)
        return kv_caches, x


class Phi3ForCausalLM(nnx.Module):

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh) -> None:
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh

        self.model = Phi3Model(
            vllm_config=vllm_config,
            rng=self.rng,
            mesh=mesh,
        )

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
        *args,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        kv_caches, x = self.model(
            kv_caches,
            input_ids,
            attention_metadata,
        )
        return kv_caches, x, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        if self.vllm_config.model_config.hf_config.tie_word_embeddings:
            logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
        else:
            logits = jnp.dot(hidden_states, self.model.lm_head.value)
        return logits

    def get_metadata_map(self) -> MetadataMap:
        sharding_size = self.mesh.shape["model"]

        model_config = self.vllm_config.model_config
        hf_config = model_config.hf_config

        num_heads = hf_config.num_attention_heads
        num_kv_heads = hf_config.num_key_value_heads
        qkv_heads = num_heads + num_kv_heads * 2
        hidden_size = model_config.get_hidden_size()

        # Pad head_dim for kernel performance.
        head_dim_original = model_config.get_head_size()

        # Key: path to a HF layer weight
        # Value: path to a nnx layer weight
        name_map = {
            "model.embed_tokens": "model.embed.embedding",
            "model.layers.*.input_layernorm":
            "model.layers.*.input_layernorm.scale",
            "model.layers.*.mlp.down_proj":
            "model.layers.*.mlp.down_proj.kernel",
            "model.layers.*.mlp.gate_up_proj":
            "model.layers.*.mlp.gate_up_proj.kernel",
            "model.layers.*.post_attention_layernorm":
            "model.layers.*.post_attention_layernorm.scale",
            "model.layers.*.self_attn.qkv_proj":
            "model.layers.*.self_attn.qkv_proj.kernel",
            "model.layers.*.self_attn.o_proj":
            "model.layers.*.self_attn.o_proj.kernel",
            "model.norm": "model.norm.scale",
        }
        if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
            name_map.update({
                "lm_head": "model.lm_head",
            })

        reshape_keys: dict[str, tuple[int, ...]] = {
            "qkv_proj": (qkv_heads, head_dim_original, hidden_size),
            "o_proj": (hidden_size, num_heads, head_dim_original),
        }
        transpose_keys: dict[str, tuple[int, ...]] = {
            "lm_head": (1, 0),
            "gate_up_proj": (1, 0),
            "down_proj": (1, 0),
            "qkv_proj": (2, 0, 1),
            "o_proj": (1, 2, 0),
        }

        # key: (padding_dim, padding_size)
        pad_keys: dict[str, tuple[int, ...]] = {
            "qkv_proj": (1, sharding_size // num_heads),
            "o_proj": (0, sharding_size // num_heads),
        }

        return MetadataMap(name_map=name_map,
                           reshape_map=reshape_keys,
                           bias_reshape_map={},
                           transpose_map=transpose_keys,
                           pad_map=pad_keys,
                           bias_pad_map={})

    def load_weights(self, rng_key: jax.Array):
        # NOTE: Since we are using nnx.eval_shape to init the model,
        # we have to pass dynamic arrays here for __call__'s usage.
        self.rng = nnx.Rngs(rng_key)

        metadata_map = self.get_metadata_map()
        load_hf_weights(vllm_config=self.vllm_config,
                        model=self,
                        metadata_map=metadata_map,
                        mesh=self.mesh)

{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL
RENAMED
File without changes

{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE
RENAMED
File without changes

{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt
RENAMED
File without changes