tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
--- /dev/null
+++ b/tpu_inference/models/jax/gpt_oss.py
@@ -0,0 +1,492 @@
+import re
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.gpt_oss_attention import (
+    AttentionMetadata, GptOssAttention)
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
+from tpu_inference.models.jax.utils.weight_utils import (
+    get_param, model_weights_generator, print_param_info)
+
+logger = init_logger(__name__)
+
+# A map from JAX dtype to the corresponding PyTorch integer dtype for raw memory viewing.
+DTYPE_VIEW_MAP = {
+    jnp.dtype(jnp.float8_e4m3fn): torch.uint8,
+    jnp.dtype(jnp.bfloat16): torch.uint16,
+    jnp.dtype(jnp.float32): torch.uint32,
+}
+
+
+@dataclass
+class GptOss(nnx.Module):
+    """
+    JAX implementation of the GPT-OSS model architecture.
+    """
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: jax.Array,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.hf_config = vllm_config.model_config.hf_config
+        self.rng = nnx.Rngs(rng)
+
+        num_layers: int = self.hf_config.num_hidden_layers
+        num_experts: int = self.hf_config.num_local_experts
+        vocab_size: int = self.hf_config.vocab_size
+        num_attention_heads: int = self.hf_config.num_attention_heads
+        num_key_value_heads: int = self.hf_config.num_key_value_heads
+        head_dim: int = self.hf_config.head_dim
+        hidden_size: int = self.hf_config.hidden_size
+        ffw_intermediate_size: int = self.hf_config.intermediate_size
+        num_experts_per_token: int = self.hf_config.num_experts_per_tok
+        rms_norm_eps: float = self.hf_config.rms_norm_eps
+        swiglu_limit: float = self.hf_config.swiglu_limit
+
+        rope_theta: float = self.hf_config.rope_theta
+        rope_scaling_factor: float = self.hf_config.rope_scaling["factor"]
+        rope_ntk_alpha: float = self.hf_config.rope_scaling["beta_slow"]
+        rope_ntk_beta: float = self.hf_config.rope_scaling["beta_fast"]
+        initial_context_length: int = self.hf_config.rope_scaling[
+            "original_max_position_embeddings"]
+
+        dtype: jnp.dtype = jnp.bfloat16
+
+        self.sliding_window = self.hf_config.sliding_window
+
+        self.random_init = force_random_weights or self.vllm_config.additional_config.get(
+            "random_weights", False)
+        self.mesh = mesh
+
+        self.embedder = Embedder(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            dtype=dtype,
+            rngs=self.rng,
+            vd_sharding=P(('data', 'model'), None),
+            random_init=self.random_init,
+        )
+
+        self.layers = []
+        for i in range(num_layers):
+            attn = GptOssAttention(
+                hidden_size=hidden_size,
+                num_attention_heads=num_attention_heads,
+                num_key_value_heads=num_key_value_heads,
+                head_dim=head_dim,
+                dtype=dtype,
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                rope_theta=rope_theta,
+                initial_context_length=initial_context_length,
+                rope_scaling_factor=rope_scaling_factor,
+                rope_ntk_alpha=rope_ntk_alpha,
+                rope_ntk_beta=rope_ntk_beta,
+                rngs=self.rng,
+                random_init=self.random_init,
+                query_tnh=P(None, 'model', None),
+                keyvalue_skh=P(None, 'model', None),
+                attn_o_tnh=P(None, 'model', None),
+                dnh_sharding=P(None, 'model', None),
+                dkh_sharding=P(None, 'model', None),
+                nhd_sharding=P('model', None, None),
+                mesh=self.mesh,
+            )
+
+            # MoE MLP block
+            router = GptOssRouter(
+                hidden_size=hidden_size,
+                num_experts=num_experts,
+                num_experts_per_tok=num_experts_per_token,
+                rngs=self.rng,
+                dtype=dtype,
+                router_act='softmax',
+                random_init=self.random_init,
+                activation_ffw_td=P('data', None),
+                ed_sharding=P('model', None),
+                e_sharding=P('model'),
+            )
+
+            moe_mlp = GptOssMoE(
+                dtype=dtype,
+                num_local_experts=num_experts,
+                hidden_size=hidden_size,
+                intermediate_size_moe=ffw_intermediate_size,
+                rngs=self.rng,
+                random_init=self.random_init,
+                router=router,
+                swiglu_limit=swiglu_limit,
+                # Sharding configuration
+                activation_ffw_td=P('data', None),
+                edf_sharding=P('model', None, None),
+                efd_sharding=P('model', None, None),
+                ed_sharding=P('model', None),
+            )
+
+            block = TransformerBlock(
+                pre_attention_norm=RMSNorm(
+                    dims=hidden_size,
+                    random_init=self.random_init,
+                    epsilon=rms_norm_eps,
+                    dtype=dtype,
+                    rngs=self.rng,
+                    activation_ffw_td=P('data', None),
+                ),
+                pre_mlp_norm=RMSNorm(
+                    dims=hidden_size,
+                    random_init=self.random_init,
+                    epsilon=rms_norm_eps,
+                    dtype=dtype,
+                    rngs=self.rng,
+                    activation_ffw_td=P('data', None),
+                ),
+                attn=attn,
+                custom_module=moe_mlp,
+            )
+            self.layers.append(block)
+        # Note: ALL RMSNorm does not upcast input to float32, while the pytorch does
+        self.final_norm = RMSNorm(
+            dims=hidden_size,
+            rngs=self.rng,
+            random_init=self.random_init,
+            epsilon=rms_norm_eps,
+            dtype=dtype,
+            activation_ffw_td=P('data', None),
+        )
+
+        self.lm_head = LMhead(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            dtype=dtype,
+            rngs=self.rng,
+            vd_sharding=P(('data', 'model'), None),
+            dv_sharding=P(None, ('data', 'model')),
+            random_init=self.random_init,
+        )
+
+    # For compatibility with flax.
+    def apply(self, variables, *args, **kwargs):
+        return self.__call__(*args, **kwargs)
+
+    def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
+        """Loads and transforms all weights from a checkpoint"""
+        self.rng = nnx.Rngs(rng)
+
+        # Determine quantization method from HF config (config.json)
+        quant_method = (self.hf_config.quantization_config["quant_method"]
+                        if hasattr(self.hf_config, "quantization_config") else
+                        None)
+
+        # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
+        transforms = {
+            "transpose_reshape": lambda w, shape: w.T.reshape(shape),
+            "reshape": lambda b, shape: b.reshape(shape),
+            "transpose": lambda w, _: w.T,
+            "swap_last2": lambda w, _: w.swapaxes(-1, -2),
+        }
+
+        # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
+        swap_mlp_transform = transforms[
+            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+
+        mappings = {
+            # Embeddings, Norms, and LM Head
+            "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
+                                          None, None),
+            "lm_head.weight": ("lm_head.input_embedding_table_DV",
+                               transforms["transpose"], None),
+            "model.norm.weight": ("final_norm.scale", None, None),
+            "model.layers.*.input_layernorm.weight":
+            ("layers.*.pre_attention_norm.scale", None, None),
+            "model.layers.*.post_attention_layernorm.weight":
+            ("layers.*.pre_mlp_norm.scale", None, None),
+
+            # Attention Weights
+            "model.layers.*.self_attn.q_proj.weight":
+            ("layers.*.attn.kernel_q_DNH", transforms["transpose_reshape"],
+             (self.hf_config.hidden_size, self.hf_config.num_attention_heads,
+              self.hf_config.head_dim)),
+            "model.layers.*.self_attn.k_proj.weight":
+            ("layers.*.attn.kernel_k_DKH", transforms["transpose_reshape"],
+             (self.hf_config.hidden_size, self.hf_config.num_key_value_heads,
+              self.hf_config.head_dim)),
+            "model.layers.*.self_attn.v_proj.weight":
+            ("layers.*.attn.kernel_v_DKH", transforms["transpose_reshape"],
+             (self.hf_config.hidden_size, self.hf_config.num_key_value_heads,
+              self.hf_config.head_dim)),
+            "model.layers.*.self_attn.o_proj.weight":
+            ("layers.*.attn.kernel_o_proj_NHD",
+             transforms["transpose_reshape"],
+             (self.hf_config.num_attention_heads, self.hf_config.head_dim,
+              self.hf_config.hidden_size)),
+
+            # Attention Biases
+            "model.layers.*.self_attn.q_proj.bias":
+            ("layers.*.attn.bias_q_NH", transforms["reshape"],
+             (self.hf_config.num_attention_heads, self.hf_config.head_dim)),
+            "model.layers.*.self_attn.k_proj.bias":
+            ("layers.*.attn.bias_k_KH", transforms["reshape"],
+             (self.hf_config.num_key_value_heads, self.hf_config.head_dim)),
+            "model.layers.*.self_attn.v_proj.bias":
+            ("layers.*.attn.bias_v_KH", transforms["reshape"],
+             (self.hf_config.num_key_value_heads, self.hf_config.head_dim)),
+            "model.layers.*.self_attn.o_proj.bias": ("layers.*.attn.bias_o_D",
+                                                     None, None),
+
+            # Sinks
+            "model.layers.*.self_attn.sinks": ("layers.*.attn.sinks_N", None,
+                                               None),
+
+            # MoE Weights
+            "model.layers.*.mlp.router.weight":
+            ("layers.*.custom_module.router.kernel_DE",
+             transforms["transpose"], None),
+            "model.layers.*.mlp.router.bias":
+            ("layers.*.custom_module.router.bias_E", None, None),
+            "model.layers.*.mlp.experts.gate_up_proj":
+            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform,
+             None),
+            "model.layers.*.mlp.experts.gate_up_proj_bias":
+            ("layers.*.custom_module.mlp1_bias_EF2", None, None),
+            "model.layers.*.mlp.experts.down_proj":
+            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform,
+             None),
+            "model.layers.*.mlp.experts.down_proj_bias":
+            ("layers.*.custom_module.mlp2_bias_ED", None, None),
+        }
+
+        model_params = nnx.state(self)
+        is_verbose = self.vllm_config.additional_config.get(
+            "is_verbose", False)
+
+        names_and_weights_generator = model_weights_generator(
+            model_name_or_path=self.vllm_config.model_config.model,
+            framework="pt",
+            download_dir=self.vllm_config.load_config.download_dir)
+
+        # Build a pool of weights with MXFP4 experts combined if neededs
+        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
+            names_and_weights_generator,
+            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+                loaded_name: loaded_weight
+                for loaded_name, loaded_weight in names_and_weights_generator
+            })
+
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in pool.items():
+                hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
+                if hf_pattern not in mappings:
+                    logger.warning(
+                        f"No mapping found for checkpoint tensor: {loaded_name}. Skipping."
+                    )
+                    continue
+
+                jax_path_template, transform_fn, target_shape = mappings[
+                    hf_pattern]
+
+                layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
+                jax_path = jax_path_template
+                if layer_num_match:
+                    jax_path = jax_path_template.replace(
+                        "*", layer_num_match.group(1))
+
+                model_weight = get_param(model_params, jax_path)
+
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
+                            blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
+                        )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequant_mxfp4_to_bf16(
+                        blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
+
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(self, model_params)
+
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect MXFP4 weights into a pool keeping tuples (blocks_u8, scales_u8).
+
+        Combines *_blocks and *_scales pairs and stores uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
+        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
+        for loaded_name, loaded_weight in names_and_weights_generator:
+            if loaded_name.endswith("_blocks") or loaded_name.endswith(
+                    "_scales"):
+                base = loaded_name[:-7]
+                entry = pending_experts.setdefault(base, {})
+                if loaded_name.endswith("_blocks"):
+                    entry["blocks"] = loaded_weight
+                else:
+                    entry["scales"] = loaded_weight
+
+                # If we have both parts, place raw pair into the main pool
+                if "blocks" in entry and "scales" in entry:
+                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
+                    if hf_pattern not in mappings:
+                        raise ValueError(
+                            f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
+                    # Remove from pending to free memory
+                    pending_experts.pop(base, None)
+            else:
+                pool[loaded_name] = loaded_weight
+
+        # Enforce completeness of expert bundles
+        if pending_experts:
+            details = []
+            for base, entry in pending_experts.items():
+                missing = [k for k in ("blocks", "scales") if k not in entry]
+                details.append(
+                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
+                )
+            raise RuntimeError(
+                "Incomplete MXFP4 expert bundle(s) encountered: " +
+                ", ".join(details))
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
+            get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
+            get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
+                            cast_type, transform_fn, target_shape,
+                            jax_path_template: str):
+        """Assign a regular tensor (non-MXFP4) into the model param with transform applied."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
+                weight_np = jnp.array(
+                    loaded_weight.view(torch_view_type).numpy()).view(
+                        cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")
+
+        transformed_weight = transform_fn(
+            weight_np, target_shape) if transform_fn else weight_np
+
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
+            )
+
+        def get_slice(index):
+            return transformed_weight[index]
+
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
+        model_weight.value = sharded_array
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
+        is_prefill = False
+        x = self.embedder.encode(input_ids)
+
+        for i, block in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            current_sliding_window = self.sliding_window if i % 2 == 0 else None
+            attention_metadata.sliding_window = current_sliding_window
+
+            new_kv_cache, x = block(x, is_prefill, kv_cache,
+                                    attention_metadata)
+            kv_caches[i] = new_kv_cache
+
+        final_activation = self.final_norm(x)
+        return kv_caches, final_activation, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        return self.lm_head.decode(hidden_states)
--- /dev/null
+++ b/tpu_inference/models/jax/jax_intermediate_tensor.py
@@ -0,0 +1,79 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, Union
+
+import jax
+from jax.tree_util import register_pytree_node_class
+from torchax.interop import jax_view, torch_view
+from vllm.sequence import IntermediateTensors
+
+if TYPE_CHECKING:
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput
+else:
+    KVConnectorOutput = Any
+
+
+@register_pytree_node_class
+@dataclass
+class JaxIntermediateTensors:
+    """For all pipeline stages except the last, we need to return the
+    intermediate tensor which is the hidden states (and residuals) to be
+    sent to the next stage. This data structure contains the
+    intermediate tensor for a request.
+
+    There is a PyTorch IntermediateTensors (in vllm/sequence.py) class in vllm
+    for the same purpose.
+
+    Each stage also needs to handle its own kv_connector_output.
+
+    This class also contains the from_torch and to_torch functions, the goal is
+    to convert between pytorch's intermediate tensor
+    and Jax's intermediate tensor in torchax path.
+    """
+
+    tensors: Dict[str, Any]
+    kv_connector_output: KVConnectorOutput = None
+
+    def tree_flatten(self):
+        children = (self.tensors, )
+        aux_data = self.kv_connector_output
+        return (children, aux_data)
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls(children[0], aux_data)
+
+    @classmethod
+    def from_torch(cls, torch_obj: IntermediateTensors):
+        kv_connector_output = getattr(torch_obj, 'kv_connector_output', None)
+        jax_tensors = {k: jax_view(v) for k, v in torch_obj.tensors.items()}
+        return cls(jax_tensors, kv_connector_output)
+
+    def to_torch(self) -> IntermediateTensors:
+        torch_tensors = {k: torch_view(v) for k, v in self.tensors.items()}
+        return IntermediateTensors(torch_tensors)
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: Any):
+        self.tensors[key] = value
+
+    def keys(self):
+        return self.tensors.keys()
+
+    def items(self):
+        return self.tensors.items()
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def block_until_ready(self):
+        for tensor in self.tensors.values():
+            assert isinstance(
+                tensor, jax.Array
+            ), "block_until_ready needs to be applied on jax arrays"
+            tensor.block_until_ready()