tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0

tpu_inference/models/jax/gpt_oss.py
@@ -0,0 +1,492 @@
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple

import jax
import jax.numpy as jnp
import torch
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from vllm.config import VllmConfig

from tpu_inference.layers.jax.attention.gpt_oss_attention import (
    AttentionMetadata, GptOssAttention)
from tpu_inference.layers.jax.constants import KVCacheType
from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
from tpu_inference.layers.jax.transformer_block import TransformerBlock
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
from tpu_inference.models.jax.utils.weight_utils import (
    get_param, model_weights_generator, print_param_info)

logger = init_logger(__name__)

# A map from JAX dtype to the corresponding PyTorch integer dtype for raw memory viewing.
DTYPE_VIEW_MAP = {
    jnp.dtype(jnp.float8_e4m3fn): torch.uint8,
    jnp.dtype(jnp.bfloat16): torch.uint16,
    jnp.dtype(jnp.float32): torch.uint32,
}


@dataclass
class GptOss(nnx.Module):
    """
    JAX implementation of the GPT-OSS model architecture.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 rng: jax.Array,
                 mesh: Mesh,
                 force_random_weights: bool = False):
        assert mesh is not None

        self.vllm_config = vllm_config
        self.hf_config = vllm_config.model_config.hf_config
        self.rng = nnx.Rngs(rng)

        num_layers: int = self.hf_config.num_hidden_layers
        num_experts: int = self.hf_config.num_local_experts
        vocab_size: int = self.hf_config.vocab_size
        num_attention_heads: int = self.hf_config.num_attention_heads
        num_key_value_heads: int = self.hf_config.num_key_value_heads
        head_dim: int = self.hf_config.head_dim
        hidden_size: int = self.hf_config.hidden_size
        ffw_intermediate_size: int = self.hf_config.intermediate_size
        num_experts_per_token: int = self.hf_config.num_experts_per_tok
        rms_norm_eps: float = self.hf_config.rms_norm_eps
        swiglu_limit: float = self.hf_config.swiglu_limit

        rope_theta: float = self.hf_config.rope_theta
        rope_scaling_factor: float = self.hf_config.rope_scaling["factor"]
        rope_ntk_alpha: float = self.hf_config.rope_scaling["beta_slow"]
        rope_ntk_beta: float = self.hf_config.rope_scaling["beta_fast"]
        initial_context_length: int = self.hf_config.rope_scaling[
            "original_max_position_embeddings"]

        dtype: jnp.dtype = jnp.bfloat16

        self.sliding_window = self.hf_config.sliding_window

        self.random_init = force_random_weights or self.vllm_config.additional_config.get(
            "random_weights", False)
        self.mesh = mesh

        self.embedder = Embedder(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            dtype=dtype,
            rngs=self.rng,
            vd_sharding=P(('data', 'model'), None),
            random_init=self.random_init,
        )

        self.layers = []
        for i in range(num_layers):
            attn = GptOssAttention(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                head_dim=head_dim,
                dtype=dtype,
                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                rope_theta=rope_theta,
                initial_context_length=initial_context_length,
                rope_scaling_factor=rope_scaling_factor,
                rope_ntk_alpha=rope_ntk_alpha,
                rope_ntk_beta=rope_ntk_beta,
                rngs=self.rng,
                random_init=self.random_init,
                query_tnh=P(None, 'model', None),
                keyvalue_skh=P(None, 'model', None),
                attn_o_tnh=P(None, 'model', None),
                dnh_sharding=P(None, 'model', None),
                dkh_sharding=P(None, 'model', None),
                nhd_sharding=P('model', None, None),
                mesh=self.mesh,
            )

            # MoE MLP block
            router = GptOssRouter(
                hidden_size=hidden_size,
                num_experts=num_experts,
                num_experts_per_tok=num_experts_per_token,
                rngs=self.rng,
                dtype=dtype,
                router_act='softmax',
                random_init=self.random_init,
                activation_ffw_td=P('data', None),
                ed_sharding=P('model', None),
                e_sharding=P('model'),
            )

            moe_mlp = GptOssMoE(
                dtype=dtype,
                num_local_experts=num_experts,
                hidden_size=hidden_size,
                intermediate_size_moe=ffw_intermediate_size,
                rngs=self.rng,
                random_init=self.random_init,
                router=router,
                swiglu_limit=swiglu_limit,
                # Sharding configuration
                activation_ffw_td=P('data', None),
                edf_sharding=P('model', None, None),
                efd_sharding=P('model', None, None),
                ed_sharding=P('model', None),
            )

            block = TransformerBlock(
                pre_attention_norm=RMSNorm(
                    dims=hidden_size,
                    random_init=self.random_init,
                    epsilon=rms_norm_eps,
                    dtype=dtype,
                    rngs=self.rng,
                    activation_ffw_td=P('data', None),
                ),
                pre_mlp_norm=RMSNorm(
                    dims=hidden_size,
                    random_init=self.random_init,
                    epsilon=rms_norm_eps,
                    dtype=dtype,
                    rngs=self.rng,
                    activation_ffw_td=P('data', None),
                ),
                attn=attn,
                custom_module=moe_mlp,
            )
            self.layers.append(block)
        # Note: none of the RMSNorm layers upcast the input to float32, unlike the PyTorch implementation.
        self.final_norm = RMSNorm(
            dims=hidden_size,
            rngs=self.rng,
            random_init=self.random_init,
            epsilon=rms_norm_eps,
            dtype=dtype,
            activation_ffw_td=P('data', None),
        )

        self.lm_head = LMhead(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            dtype=dtype,
            rngs=self.rng,
            vd_sharding=P(('data', 'model'), None),
            dv_sharding=P(None, ('data', 'model')),
            random_init=self.random_init,
        )

    # For compatibility with flax.
    def apply(self, variables, *args, **kwargs):
        return self.__call__(*args, **kwargs)

    def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
        """Loads and transforms all weights from a checkpoint."""
        self.rng = nnx.Rngs(rng)

        # Determine quantization method from HF config (config.json)
        quant_method = (self.hf_config.quantization_config["quant_method"]
                        if hasattr(self.hf_config, "quantization_config") else
                        None)

        # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
        transforms = {
            "transpose_reshape": lambda w, shape: w.T.reshape(shape),
            "reshape": lambda b, shape: b.reshape(shape),
            "transpose": lambda w, _: w.T,
            "swap_last2": lambda w, _: w.swapaxes(-1, -2),
        }

        # MXFP4 checkpoints swap the last two MoE dims so the packed dim is the minormost one
        swap_mlp_transform = transforms[
            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None

        mappings = {
            # Embeddings, Norms, and LM Head
            "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
                                          None, None),
            "lm_head.weight": ("lm_head.input_embedding_table_DV",
                               transforms["transpose"], None),
            "model.norm.weight": ("final_norm.scale", None, None),
            "model.layers.*.input_layernorm.weight":
            ("layers.*.pre_attention_norm.scale", None, None),
            "model.layers.*.post_attention_layernorm.weight":
            ("layers.*.pre_mlp_norm.scale", None, None),

            # Attention Weights
            "model.layers.*.self_attn.q_proj.weight":
            ("layers.*.attn.kernel_q_DNH", transforms["transpose_reshape"],
             (self.hf_config.hidden_size, self.hf_config.num_attention_heads,
              self.hf_config.head_dim)),
            "model.layers.*.self_attn.k_proj.weight":
            ("layers.*.attn.kernel_k_DKH", transforms["transpose_reshape"],
             (self.hf_config.hidden_size, self.hf_config.num_key_value_heads,
              self.hf_config.head_dim)),
            "model.layers.*.self_attn.v_proj.weight":
            ("layers.*.attn.kernel_v_DKH", transforms["transpose_reshape"],
             (self.hf_config.hidden_size, self.hf_config.num_key_value_heads,
              self.hf_config.head_dim)),
            "model.layers.*.self_attn.o_proj.weight":
            ("layers.*.attn.kernel_o_proj_NHD",
             transforms["transpose_reshape"],
             (self.hf_config.num_attention_heads, self.hf_config.head_dim,
              self.hf_config.hidden_size)),

            # Attention Biases
            "model.layers.*.self_attn.q_proj.bias":
            ("layers.*.attn.bias_q_NH", transforms["reshape"],
             (self.hf_config.num_attention_heads, self.hf_config.head_dim)),
            "model.layers.*.self_attn.k_proj.bias":
            ("layers.*.attn.bias_k_KH", transforms["reshape"],
             (self.hf_config.num_key_value_heads, self.hf_config.head_dim)),
            "model.layers.*.self_attn.v_proj.bias":
            ("layers.*.attn.bias_v_KH", transforms["reshape"],
             (self.hf_config.num_key_value_heads, self.hf_config.head_dim)),
            "model.layers.*.self_attn.o_proj.bias": ("layers.*.attn.bias_o_D",
                                                     None, None),

            # Sinks
            "model.layers.*.self_attn.sinks": ("layers.*.attn.sinks_N", None,
                                               None),

            # MoE Weights
            "model.layers.*.mlp.router.weight":
            ("layers.*.custom_module.router.kernel_DE",
             transforms["transpose"], None),
            "model.layers.*.mlp.router.bias":
            ("layers.*.custom_module.router.bias_E", None, None),
            "model.layers.*.mlp.experts.gate_up_proj":
            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform,
             None),
            "model.layers.*.mlp.experts.gate_up_proj_bias":
            ("layers.*.custom_module.mlp1_bias_EF2", None, None),
            "model.layers.*.mlp.experts.down_proj":
            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform,
             None),
            "model.layers.*.mlp.experts.down_proj_bias":
            ("layers.*.custom_module.mlp2_bias_ED", None, None),
        }

        model_params = nnx.state(self)
        is_verbose = self.vllm_config.additional_config.get(
            "is_verbose", False)

        names_and_weights_generator = model_weights_generator(
            model_name_or_path=self.vllm_config.model_config.model,
            framework="pt",
            download_dir=self.vllm_config.load_config.download_dir)

        # Build a pool of weights with MXFP4 experts combined if needed
        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
            names_and_weights_generator,
            mappings) if quant_method == MXFP4_QUANT_METHOD else {
                loaded_name: loaded_weight
                for loaded_name, loaded_weight in names_and_weights_generator
            })

        with jax.default_device(jax.devices("cpu")[0]):
            for loaded_name, loaded_weight in pool.items():
                hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
                if hf_pattern not in mappings:
                    logger.warning(
                        f"No mapping found for checkpoint tensor: {loaded_name}. Skipping."
                    )
                    continue

                jax_path_template, transform_fn, target_shape = mappings[
                    hf_pattern]

                layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
                jax_path = jax_path_template
                if layer_num_match:
                    jax_path = jax_path_template.replace(
                        "*", layer_num_match.group(1))

                model_weight = get_param(model_params, jax_path)

                prepared_weight = loaded_weight
                if isinstance(loaded_weight, tuple):
                    # Loaded weight is an MXFP4 tuple
                    blocks_u8, scales_u8 = loaded_weight
                    # Quantized param (QArray): set qvalue/scale directly and skip regular path
                    if hasattr(model_weight, "array"):  # QArray check
                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
                            blocks_u8, scales_u8)
                        self._load_mxfp4(
                            model_weight=model_weight,
                            codes_fp32_t=codes_fp32_t,
                            scales_fp32_t=scales_fp32_t,
                            transform_fn=transform_fn,
                        )
                        if is_verbose:
                            print_param_info(model_weight, loaded_name)
                        continue
                    # Not a QArray: dequantize MXFP4 to BF16 full weights
                    prepared_weight = dequant_mxfp4_to_bf16(
                        blocks_u8, scales_u8)

                # Single regular-tensor load call (BF16 or dequantized MXFP4)
                cast_type = model_weight.value.dtype
                self._load_regular_param(
                    model_weight=model_weight,
                    loaded_weight=prepared_weight,
                    cast_type=cast_type,
                    transform_fn=transform_fn,
                    target_shape=target_shape,
                    jax_path_template=jax_path_template,
                )

                if is_verbose:
                    print_param_info(model_weight, loaded_name)

        nnx.update(self, model_params)

    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
        """Collect MXFP4 weights into a pool, keeping tuples (blocks_u8, scales_u8).

        Combines *_blocks and *_scales pairs and stores the uint8 tensors together.
        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
        """
        pool: dict[str, torch.Tensor | tuple] = {}
        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
        for loaded_name, loaded_weight in names_and_weights_generator:
            if loaded_name.endswith("_blocks") or loaded_name.endswith(
                    "_scales"):
                base = loaded_name[:-7]
                entry = pending_experts.setdefault(base, {})
                if loaded_name.endswith("_blocks"):
                    entry["blocks"] = loaded_weight
                else:
                    entry["scales"] = loaded_weight

                # If we have both parts, place the raw pair into the main pool
                if "blocks" in entry and "scales" in entry:
                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
                    if hf_pattern not in mappings:
                        raise ValueError(
                            f"No mapping found for expert tensor: {base}")
                    pool[base] = (entry["blocks"], entry["scales"])
                    # Remove from pending to free memory
                    pending_experts.pop(base, None)
            else:
                pool[loaded_name] = loaded_weight

        # Enforce completeness of expert bundles
        if pending_experts:
            details = []
            for base, entry in pending_experts.items():
                missing = [k for k in ("blocks", "scales") if k not in entry]
                details.append(
                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
                )
            raise RuntimeError(
                "Incomplete MXFP4 expert bundle(s) encountered: " +
                ", ".join(details))
        return pool

    def _load_mxfp4(self,
                    model_weight,
                    codes_fp32_t,
                    scales_fp32_t,
                    transform_fn=None):
        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""

        qv = model_weight.array.qvalue
        sv = model_weight.array.scale
        q_dtype = qv.value.dtype
        s_dtype = sv.value.dtype

        exp_q_shape = tuple(qv.value.shape)
        exp_s_shape = tuple(sv.value.shape)

        # Apply optional transform (e.g., swap last two dims) before conversion
        if transform_fn is not None:
            codes_fp32_t = transform_fn(codes_fp32_t, None)
            scales_fp32_t = transform_fn(scales_fp32_t, None)

        # Convert from torch.Tensor to numpy before creating JAX arrays
        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()

        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)

        def get_q_slice(index):
            return codes_jnp[index]

        def get_s_slice(index):
            return scales_jnp[index]

        q_sharded = jax.make_array_from_callback(
            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
            get_q_slice)
        s_sharded = jax.make_array_from_callback(
            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
            get_s_slice)

        model_weight.array.qvalue.value = q_sharded
        model_weight.array.scale.value = s_sharded

    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
                            cast_type, transform_fn, target_shape,
                            jax_path_template: str):
        """Assign a regular (non-MXFP4) tensor into the model param with its transform applied."""
        if jax_path_template == "layers.*.attn.sinks_N":
            # The checkpoint is bf16, but sinks must be upcast to f32, as required by the RPA_v3 kernel
            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
        else:
            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
            if torch_view_type:
                weight_np = jnp.array(
                    loaded_weight.view(torch_view_type).numpy()).view(
                        cast_type)
            else:
                raise ValueError(
                    f"Unsupported dtype for tensor conversion: {cast_type}")

        transformed_weight = transform_fn(
            weight_np, target_shape) if transform_fn else weight_np

        if model_weight.value.shape != transformed_weight.shape:
            raise ValueError(
                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
            )

        def get_slice(index):
            return transformed_weight[index]

        sharded_array = jax.make_array_from_callback(
            transformed_weight.shape,
            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
        model_weight.value = sharded_array

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
        *args,
    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
        is_prefill = False
        x = self.embedder.encode(input_ids)

        for i, block in enumerate(self.layers):
            kv_cache = kv_caches[i]
            current_sliding_window = self.sliding_window if i % 2 == 0 else None
            attention_metadata.sliding_window = current_sliding_window

            new_kv_cache, x = block(x, is_prefill, kv_cache,
                                    attention_metadata)
            kv_caches[i] = new_kv_cache

        final_activation = self.final_norm(x)
        return kv_caches, final_activation, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        return self.lm_head.decode(hidden_states)
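
The checkpoint-name resolution in load_weights above is a regex round-trip over the `layers.*` wildcard: collapse the concrete layer index to look up the mapping, then substitute it back into the JAX parameter path. A minimal, self-contained sketch of that step follows; the mapping entry is copied from the diff, while the `resolve` helper is illustrative and not part of the package.

import re

# One entry in the same ('jax_model_path', transform_function, target_shape) format as `mappings` above.
mappings = {
    "model.layers.*.input_layernorm.weight":
    ("layers.*.pre_attention_norm.scale", None, None),
}


def resolve(loaded_name):
    # Collapse the concrete layer index into the wildcard key used by `mappings`.
    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
    if hf_pattern not in mappings:
        return None
    jax_path_template, transform_fn, target_shape = mappings[hf_pattern]
    # Substitute the concrete layer index back into the JAX parameter path.
    match = re.search(r"layers\.(\d+)", loaded_name)
    jax_path = (jax_path_template.replace("*", match.group(1))
                if match else jax_path_template)
    return jax_path, transform_fn, target_shape


print(resolve("model.layers.7.input_layernorm.weight"))
# -> ('layers.7.pre_attention_norm.scale', None, None)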

tpu_inference/models/jax/jax_intermediate_tensor.py
@@ -0,0 +1,79 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Union

import jax
from jax.tree_util import register_pytree_node_class
from torchax.interop import jax_view, torch_view
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput
else:
    KVConnectorOutput = Any


@register_pytree_node_class
@dataclass
class JaxIntermediateTensors:
    """For all pipeline stages except the last, we need to return the
    intermediate tensors, i.e. the hidden states (and residuals), to be
    sent to the next stage. This data structure holds those intermediate
    tensors for a request.

    A PyTorch IntermediateTensors class (in vllm/sequence.py) exists in vLLM
    for the same purpose.

    Each stage also needs to handle its own kv_connector_output.

    The from_torch and to_torch methods convert between PyTorch's and JAX's
    intermediate tensors on the torchax path.
    """

    tensors: Dict[str, Any]
    kv_connector_output: KVConnectorOutput = None

    def tree_flatten(self):
        children = (self.tensors, )
        aux_data = self.kv_connector_output
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

    @classmethod
    def from_torch(cls, torch_obj: IntermediateTensors):
        kv_connector_output = getattr(torch_obj, 'kv_connector_output', None)
        jax_tensors = {k: jax_view(v) for k, v in torch_obj.tensors.items()}
        return cls(jax_tensors, kv_connector_output)

    def to_torch(self) -> IntermediateTensors:
        torch_tensors = {k: torch_view(v) for k, v in self.tensors.items()}
        return IntermediateTensors(torch_tensors)

    def __getitem__(self, key: Union[str, slice]):
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: Any):
        self.tensors[key] = value

    def keys(self):
        return self.tensors.keys()

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def block_until_ready(self):
        for tensor in self.tensors.values():
            assert isinstance(
                tensor, jax.Array
            ), "block_until_ready needs to be applied on jax arrays"
            tensor.block_until_ready()
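
Because the class is registered as a pytree, JAX tree utilities can traverse the arrays stored in `tensors` while `kv_connector_output` rides along as auxiliary (non-leaf) data. A minimal usage sketch, assuming the package and its vllm/torchax dependencies are importable; the array values here are illustrative only.

import jax
import jax.numpy as jnp

from tpu_inference.models.jax.jax_intermediate_tensor import \
    JaxIntermediateTensors

# Build a container with an illustrative hidden-states array.
it = JaxIntermediateTensors(tensors={"hidden_states": jnp.ones((2, 4))})

# tree_map reaches the jax.Array leaves stored under .tensors ...
doubled = jax.tree_util.tree_map(lambda x: x * 2, it)

# ... and flatten/unflatten round-trips the container, carrying
# kv_connector_output through aux_data rather than as a leaf.
leaves, treedef = jax.tree_util.tree_flatten(it)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert restored.keys() == it.keys()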