tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/gpt_oss.py
@@ -0,0 +1,508 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.common.quant_methods import MXFP4
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, e8m0_to_fp32, u8_unpack_e2m1)
+from tpu_inference.layers.jax.attention.gpt_oss_attention import (
+    AttentionMetadata, GptOssAttention)
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (
+    get_param, model_weights_generator, print_param_info)
+
+logger = init_logger(__name__)
+
+# A map from JAX dtype to the corresponding PyTorch integer dtype for raw memory viewing.
+DTYPE_VIEW_MAP = {
+    jnp.dtype(jnp.float8_e4m3fn): torch.uint8,
+    jnp.dtype(jnp.bfloat16): torch.uint16,
+    jnp.dtype(jnp.float32): torch.uint32,
+}
+
+
+@dataclass
+class GptOss(nnx.Module):
+    """
+    JAX implementation of the GPT-OSS model architecture.
+    """
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: jax.Array,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.hf_config = vllm_config.model_config.hf_config
+        self.rng = nnx.Rngs(rng)
+
+        num_layers: int = self.hf_config.num_hidden_layers
+        num_experts: int = self.hf_config.num_local_experts
+        vocab_size: int = self.hf_config.vocab_size
+        num_attention_heads: int = self.hf_config.num_attention_heads
+        num_key_value_heads: int = self.hf_config.num_key_value_heads
+        head_dim: int = self.hf_config.head_dim
+        hidden_size: int = self.hf_config.hidden_size
+        ffw_intermediate_size: int = self.hf_config.intermediate_size
+        num_experts_per_token: int = self.hf_config.num_experts_per_tok
+        rms_norm_eps: float = self.hf_config.rms_norm_eps
+        swiglu_limit: float = self.hf_config.swiglu_limit
+
+        rope_theta: float = self.hf_config.rope_theta
+        rope_scaling_factor: float = self.hf_config.rope_scaling["factor"]
+        rope_ntk_alpha: float = self.hf_config.rope_scaling["beta_slow"]
+        rope_ntk_beta: float = self.hf_config.rope_scaling["beta_fast"]
+        initial_context_length: int = self.hf_config.rope_scaling[
+            "original_max_position_embeddings"]
+
+        dtype: jnp.dtype = jnp.bfloat16
+
+        self.sliding_window = self.hf_config.sliding_window
+
+        self.random_init = force_random_weights or self.vllm_config.additional_config.get(
+            "random_weights", False)
+        self.mesh = mesh
+
+        self.embedder = Embedder(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            dtype=dtype,
+            rngs=self.rng,
+            vd_sharding=P(('data', 'model'), None),
+            random_init=self.random_init,
+        )
+
+        self.layers = []
+        for i in range(num_layers):
+            attn = GptOssAttention(
+                hidden_size=hidden_size,
+                num_attention_heads=num_attention_heads,
+                num_key_value_heads=num_key_value_heads,
+                head_dim=head_dim,
+                dtype=dtype,
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                rope_theta=rope_theta,
+                initial_context_length=initial_context_length,
+                rope_scaling_factor=rope_scaling_factor,
+                rope_ntk_alpha=rope_ntk_alpha,
+                rope_ntk_beta=rope_ntk_beta,
+                rngs=self.rng,
+                random_init=self.random_init,
+                query_tnh=P("data", 'model', None),
+                keyvalue_skh=P("data", 'model', None),
+                attn_o_tnh=P("data", 'model', None),
+                dnh_sharding=P(None, 'model', None),
+                dkh_sharding=P(None, 'model', None),
+                nhd_sharding=P('model', None, None),
+                mesh=self.mesh,
+            )
+
+            # MoE MLP block
+            router = GptOssRouter(
+                hidden_size=hidden_size,
+                num_experts=num_experts,
+                num_experts_per_tok=num_experts_per_token,
+                rngs=self.rng,
+                dtype=dtype,
+                router_act='softmax',
+                random_init=self.random_init,
+                activation_ffw_td=P('data', None),
+                ed_sharding=P('model', None),
+                e_sharding=P('model'),
+            )
+
+            moe_mlp = GptOssMoE(
+                dtype=dtype,
+                num_local_experts=num_experts,
+                hidden_size=hidden_size,
+                intermediate_size_moe=ffw_intermediate_size,
+                rngs=self.rng,
+                random_init=self.random_init,
+                router=router,
+                swiglu_limit=swiglu_limit,
+                # Sharding configuration
+                activation_ffw_td=P('data', None),
+                edf_sharding=P('model', None, None),
+                efd_sharding=P('model', None, None),
+                ed_sharding=P('model', None),
+            )
+
+            block = TransformerBlock(
+                pre_attention_norm=RMSNorm(
+                    dims=hidden_size,
+                    random_init=self.random_init,
+                    epsilon=rms_norm_eps,
+                    dtype=dtype,
+                    rngs=self.rng,
+                    activation_ffw_td=P('data', None),
+                ),
+                pre_mlp_norm=RMSNorm(
+                    dims=hidden_size,
+                    random_init=self.random_init,
+                    epsilon=rms_norm_eps,
+                    dtype=dtype,
+                    rngs=self.rng,
+                    activation_ffw_td=P('data', None),
+                ),
+                attn=attn,
+                custom_module=moe_mlp,
+            )
+            self.layers.append(block)
+        # Note: none of the RMSNorm layers here upcast the input to float32,
+        # while the PyTorch implementation does.
+        self.final_norm = RMSNorm(
+            dims=hidden_size,
+            rngs=self.rng,
+            random_init=self.random_init,
+            epsilon=rms_norm_eps,
+            dtype=dtype,
+            activation_ffw_td=P('data', None),
+        )
+
+        self.lm_head = LMhead(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            dtype=dtype,
+            rngs=self.rng,
+            vd_sharding=P(('data', 'model'), None),
+            dv_sharding=P(None, ('data', 'model')),
+            random_init=self.random_init,
+        )
+
+    # For compatibility with flax.
+    def apply(self, variables, *args, **kwargs):
+        return self.__call__(*args, **kwargs)
+
+    def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
+        """Loads and transforms all weights from a checkpoint."""
+        self.rng = nnx.Rngs(rng)
+
+        # Determine quantization method from HF config (config.json)
+        quant_method = (self.hf_config.quantization_config["quant_method"]
+                        if hasattr(self.hf_config, "quantization_config") else
+                        None)
+
+        transforms = {
+            "transpose_reshape": lambda w, shape: w.T.reshape(shape),
+            "reshape": lambda b, shape: b.reshape(shape),
+            "transpose": lambda w, _: w.T,
+            "swap_last2": lambda w, _: w.swapaxes(-1, -2),
+        }
+
+        # MXFP4 checkpoints swap the last two dims of the MoE weights so that
+        # the packed dim ends up minormost.
+        swap_mlp_transform = transforms[
+            "swap_last2"] if quant_method == MXFP4 else None
+
+        # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
+        mappings = {
+            # Embeddings, Norms, and LM Head
+            "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
+                                          None, None),
+            "lm_head.weight": ("lm_head.input_embedding_table_DV",
+                               transforms["transpose"], None),
+            "model.norm.weight": ("final_norm.scale", None, None),
+            "model.layers.*.input_layernorm.weight":
+            ("layers.*.pre_attention_norm.scale", None, None),
+            "model.layers.*.post_attention_layernorm.weight":
+            ("layers.*.pre_mlp_norm.scale", None, None),
+
+            # Attention Weights
+            "model.layers.*.self_attn.q_proj.weight":
+            ("layers.*.attn.kernel_q_DNH", transforms["transpose_reshape"],
+             (self.hf_config.hidden_size, self.hf_config.num_attention_heads,
+              self.hf_config.head_dim)),
+            "model.layers.*.self_attn.k_proj.weight":
+            ("layers.*.attn.kernel_k_DKH", transforms["transpose_reshape"],
+             (self.hf_config.hidden_size, self.hf_config.num_key_value_heads,
+              self.hf_config.head_dim)),
+            "model.layers.*.self_attn.v_proj.weight":
+            ("layers.*.attn.kernel_v_DKH", transforms["transpose_reshape"],
+             (self.hf_config.hidden_size, self.hf_config.num_key_value_heads,
+              self.hf_config.head_dim)),
+            "model.layers.*.self_attn.o_proj.weight":
+            ("layers.*.attn.kernel_o_proj_NHD",
+             transforms["transpose_reshape"],
+             (self.hf_config.num_attention_heads, self.hf_config.head_dim,
+              self.hf_config.hidden_size)),
+
+            # Attention Biases
+            "model.layers.*.self_attn.q_proj.bias":
+            ("layers.*.attn.bias_q_NH", transforms["reshape"],
+             (self.hf_config.num_attention_heads, self.hf_config.head_dim)),
+            "model.layers.*.self_attn.k_proj.bias":
+            ("layers.*.attn.bias_k_KH", transforms["reshape"],
+             (self.hf_config.num_key_value_heads, self.hf_config.head_dim)),
+            "model.layers.*.self_attn.v_proj.bias":
+            ("layers.*.attn.bias_v_KH", transforms["reshape"],
+             (self.hf_config.num_key_value_heads, self.hf_config.head_dim)),
+            "model.layers.*.self_attn.o_proj.bias": ("layers.*.attn.bias_o_D",
+                                                     None, None),
+
+            # Sinks
+            "model.layers.*.self_attn.sinks": ("layers.*.attn.sinks_N", None,
+                                               None),
+
+            # MoE Weights
+            "model.layers.*.mlp.router.weight":
+            ("layers.*.custom_module.router.kernel_DE",
+             transforms["transpose"], None),
+            "model.layers.*.mlp.router.bias":
+            ("layers.*.custom_module.router.bias_E", None, None),
+            "model.layers.*.mlp.experts.gate_up_proj":
+            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform,
+             None),
+            "model.layers.*.mlp.experts.gate_up_proj_bias":
+            ("layers.*.custom_module.mlp1_bias_EF2", None, None),
+            "model.layers.*.mlp.experts.down_proj":
+            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform,
+             None),
+            "model.layers.*.mlp.experts.down_proj_bias":
+            ("layers.*.custom_module.mlp2_bias_ED", None, None),
+        }
+
+        model_params = nnx.state(self)
+        is_verbose = self.vllm_config.additional_config.get(
+            "is_verbose", False)
+
+        names_and_weights_generator = model_weights_generator(
+            model_name_or_path=self.vllm_config.model_config.model,
+            framework="pt",
+            download_dir=self.vllm_config.load_config.download_dir)
+
+        # Build a pool of weights, with MXFP4 experts combined if needed
+        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
+            names_and_weights_generator,
+            mappings) if quant_method == MXFP4 else {
+                loaded_name: loaded_weight
+                for loaded_name, loaded_weight in names_and_weights_generator
+            })
+
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in pool.items():
+                hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
+                if hf_pattern not in mappings:
+                    logger.warning(
+                        f"No mapping found for checkpoint tensor: {loaded_name}. Skipping."
+                    )
+                    continue
+
+                jax_path_template, transform_fn, target_shape = mappings[
+                    hf_pattern]
+
+                layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
+                jax_path = jax_path_template
+                if layer_num_match:
+                    jax_path = jax_path_template.replace(
+                        "*", layer_num_match.group(1))
+
+                model_weight = get_param(model_params, jax_path)
+
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip the regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t = u8_unpack_e2m1(blocks_u8).astype(
+                            jnp.float32)
+                        scales_fp32_t = e8m0_to_fp32(scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
+                        )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequantize_tensor_from_mxfp4_packed(
+                        blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
+
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(self, model_params)
+
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect MXFP4 weights into a pool, keeping (blocks_u8, scales_u8) tuples.
+
+        Combines *_blocks and *_scales pairs and stores the uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
+        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
+        for loaded_name, loaded_weight in names_and_weights_generator:
+            if loaded_name.endswith("_blocks") or loaded_name.endswith(
+                    "_scales"):
+                base = loaded_name[:-7]
+                entry = pending_experts.setdefault(base, {})
+                if loaded_name.endswith("_blocks"):
+                    entry["blocks"] = loaded_weight
+                else:
+                    entry["scales"] = loaded_weight
+
+                # If we have both parts, place the raw pair into the main pool
+                if "blocks" in entry and "scales" in entry:
+                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
+                    if hf_pattern not in mappings:
+                        raise ValueError(
+                            f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
+                    # Remove from pending to free memory
+                    pending_experts.pop(base, None)
+            else:
+                pool[loaded_name] = loaded_weight
+
+        # Enforce completeness of expert bundles
+        if pending_experts:
+            details = []
+            for base, entry in pending_experts.items():
+                missing = [k for k in ("blocks", "scales") if k not in entry]
+                details.append(
+                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
+                )
+            raise RuntimeError(
+                "Incomplete MXFP4 expert bundle(s) encountered: " +
+                ", ".join(details))
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
+            get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
+            get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
+                            cast_type, transform_fn, target_shape,
+                            jax_path_template: str):
+        """Assign a regular (non-MXFP4) tensor into the model param, applying the transform."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # The checkpoint is bf16, but we have to upcast sinks to f32, as
+            # required by the RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
+                weight_np = jnp.array(
+                    loaded_weight.view(torch_view_type).numpy()).view(
+                        cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")
+
+        transformed_weight = transform_fn(
+            weight_np, target_shape) if transform_fn else weight_np
+
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
+            )
+
+        def get_slice(index):
+            return transformed_weight[index]
+
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
+        model_weight.value = sharded_array
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
+        is_prefill = False
+        x = self.embedder.encode(input_ids)
+
+        for i, block in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            current_sliding_window = self.sliding_window if i % 2 == 0 else None
+            attention_metadata.sliding_window = current_sliding_window
+
+            new_kv_cache, x = block(x, is_prefill, kv_cache,
+                                    attention_metadata)
+            kv_caches[i] = new_kv_cache
+
+        final_activation = self.final_norm(x)
+        return kv_caches, final_activation, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        return self.lm_head.decode(hidden_states)
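
The load_weights loop above resolves checkpoint keys against the wildcard mappings table by normalizing the layer index in one direction and substituting it back in the other. A minimal sketch of that resolution, using a hypothetical checkpoint key and the q_proj entry copied from the table above:

    import re

    # Hypothetical checkpoint key, following the loop in load_weights:
    loaded_name = "model.layers.3.self_attn.q_proj.weight"

    # Normalize the layer index to the wildcard form used as a mappings key...
    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
    assert hf_pattern == "model.layers.*.self_attn.q_proj.weight"

    # ...then substitute the concrete index back into the JAX-side template.
    jax_path_template = "layers.*.attn.kernel_q_DNH"
    layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
    jax_path = jax_path_template.replace("*", layer_num_match.group(1))
    assert jax_path == "layers.3.attn.kernel_q_DNH"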
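_load_regular_param moves bf16/f32 weights from torch to JAX without a value cast by reinterpreting raw bits through DTYPE_VIEW_MAP (bfloat16 has no native NumPy dtype, so the bits travel as uint16). A standalone sketch of that round trip, assuming a PyTorch build that exposes torch.uint16; the variable names are illustrative:

    import jax.numpy as jnp
    import torch

    # View the bf16 tensor as uint16 (a bitwise reinterpretation, no cast),
    # move the raw bits via NumPy, then view them as bfloat16 on the JAX side.
    w_torch = torch.randn(2, 3, dtype=torch.bfloat16)
    raw_bits = w_torch.view(torch.uint16).numpy()
    w_jax = jnp.array(raw_bits).view(jnp.bfloat16)

    assert w_jax.shape == (2, 3)
    assert float(w_jax[0, 0]) == float(w_torch[0, 0])  # identical bits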
tpu_inference/models/jax/jax_intermediate_tensor.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, Union
+
+import jax
+from jax.tree_util import register_pytree_node_class
+from torchax.interop import jax_view, torch_view
+from vllm.sequence import IntermediateTensors
+
+if TYPE_CHECKING:
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput
+else:
+    KVConnectorOutput = Any
+
+
+@register_pytree_node_class
+@dataclass
+class JaxIntermediateTensors:
+    """For all pipeline stages except the last, we need to return the
+    intermediate tensors (the hidden states and residuals) to be sent to
+    the next stage. This data structure holds the intermediate tensors
+    for a request.
+
+    vLLM has a PyTorch IntermediateTensors class (in vllm/sequence.py)
+    that serves the same purpose.
+
+    Each stage also needs to handle its own kv_connector_output.
+
+    This class also provides from_torch and to_torch helpers for
+    converting between PyTorch and JAX intermediate tensors on the
+    torchax path.
+    """
+
+    tensors: Dict[str, Any]
+    kv_connector_output: KVConnectorOutput = None
+
+    def tree_flatten(self):
+        children = (self.tensors, )
+        aux_data = self.kv_connector_output
+        return (children, aux_data)
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls(children[0], aux_data)
+
+    @classmethod
+    def from_torch(cls, torch_obj: IntermediateTensors):
+        kv_connector_output = getattr(torch_obj, 'kv_connector_output', None)
+        jax_tensors = {k: jax_view(v) for k, v in torch_obj.tensors.items()}
+        return cls(jax_tensors, kv_connector_output)
+
+    def to_torch(self) -> IntermediateTensors:
+        torch_tensors = {k: torch_view(v) for k, v in self.tensors.items()}
+        return IntermediateTensors(torch_tensors)
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: Any):
+        self.tensors[key] = value
+
+    def keys(self):
+        return self.tensors.keys()
+
+    def items(self):
+        return self.tensors.items()
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def block_until_ready(self):
+        for tensor in self.tensors.values():
+            assert isinstance(
+                tensor, jax.Array
+            ), "block_until_ready needs to be applied on jax arrays"
+            tensor.block_until_ready()
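
JaxIntermediateTensors composes with jax.tree_util (and, by extension, jit tracing) because register_pytree_node_class wires its tree_flatten/tree_unflatten hooks into JAX: the tensors dict is flattened into traced leaves while kv_connector_output rides along as auxiliary data. A minimal standalone sketch of the same pattern, using a simplified stand-in class rather than the tpu_inference one:

    from dataclasses import dataclass
    from typing import Any, Dict

    import jax
    import jax.numpy as jnp
    from jax.tree_util import register_pytree_node_class

    @register_pytree_node_class
    @dataclass
    class Intermediates:
        """Simplified stand-in mirroring JaxIntermediateTensors' pytree hooks."""
        tensors: Dict[str, Any]
        aux: Any = None

        def tree_flatten(self):
            # Tensors become traced leaves; aux is carried as static metadata.
            return ((self.tensors, ), self.aux)

        @classmethod
        def tree_unflatten(cls, aux, children):
            return cls(children[0], aux)

    it = Intermediates(
        {"hidden_states": jnp.ones((2, 4)), "residual": jnp.zeros((2, 4))},
        aux="stage-0")

    # Because the class is a registered pytree, tree_map reaches the dict
    # values and reconstructs the wrapper with the aux data preserved.
    doubled = jax.tree_util.tree_map(lambda x: x * 2, it)
    assert doubled.aux == "stage-0"
    assert float(doubled.tensors["hidden_states"][0, 0]) == 2.0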