tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/models/jax/test_deepseek_v3.py

@@ -0,0 +1,401 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+import torch
+from flax import nnx
+from jax.sharding import Mesh
+from vllm.config import ModelConfig
+
+# Assuming the model file is named deepseek_v3.py
+from tpu_inference.models.jax.deepseek_v3 import (DeepSeekV3,
+                                                  DeepSeekV3WeightLoader)
+
+
+class MockVariable:
+    """Mocks an nnx.Variable or a QArray structure."""
+
+    def __init__(self, shape, dtype=jnp.bfloat16, sharding=None):
+        self.value = jnp.zeros(shape, dtype=dtype)
+        self.sharding = sharding or (None, ) * len(shape)
+        self.nbytes = self.value.nbytes
+        # Handle the QArray structure used in the loader
+        self.array = SimpleNamespace(
+            qvalue=self,
+            scale=SimpleNamespace(
+                value=jnp.ones((1, )),
+                nbytes=4,
+                sharding=None,
+                addressable_shards=[SimpleNamespace(data=jnp.ones((1, )))]))
+        self.addressable_shards = [SimpleNamespace(data=self.value)]
+
+
+class MockVllmConfig:
+    """Mock VllmConfig for DeepSeekV3."""
+
+    def __init__(self,
+                 model_name: str = "deepseek-ai/DeepSeek-V3",
+                 use_mla: bool = False):
+        self.model_config = MagicMock(spec=ModelConfig)
+        self.model_config.model = model_name
+        self.model_config.use_mla = use_mla
+
+        # DeepSeek V3 specific config
+        hf_config = MagicMock()
+        hf_config.num_hidden_layers = 1  # Small for testing
+        hf_config.num_nextn_predict_layers = 1
+        self.model_config.hf_config = hf_config
+
+        self.load_config = MagicMock()
+        self.load_config.download_dir = None
+
+        self.cache_config = MagicMock()
+        self.cache_config.cache_dtype = "auto"
+
+        self.additional_config = {
+            "random_weights": False,
+            "sparse_matmul": False,
+            "is_verbose": True
+        }
+
+
+@pytest.fixture(scope="module")
+def mesh():
+    if not jax.devices():
+        pytest.skip("No JAX devices available.")
+    devices = np.array(jax.local_devices())
+    num_devices = len(devices)
+    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+    # Simplify axis names for testing
+    with Mesh(device_mesh,
+              axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
+        yield m
+
+
+@pytest.fixture
+def rng():
+    return jax.random.PRNGKey(0)
+
+
+@pytest.fixture
+def mock_config():
+    return MockVllmConfig()
+
+
+class TestDeepSeekV3:
+
+    def test_init(self, mock_config, rng, mesh):
+        """Tests if the model initializes with the correct hierarchy."""
+        model = DeepSeekV3(mock_config, rng, mesh)
+        assert len(model.layers) == 3  # num_layers from mock
+        assert isinstance(model.embedder, nnx.Module)
+        assert model.vllm_config.model_config.hf_config.num_hidden_layers == 1
+
+    def test_random_weights(self, mock_config, rng, mesh):
+        """Tests that force_random_weights initializes non-zero weights."""
+        with jax.set_mesh(mesh):
+            model = DeepSeekV3(mock_config,
+                               rng,
+                               mesh,
+                               force_random_weights=True)
+            # Check embedding
+            weight = model.embedder.input_embedding_table_VD.value
+            assert jnp.std(weight) > 0
+            # Check a layer norm (should be 1s usually, but check existence)
+            assert model.final_norm.scale.value.shape == (7168, )
+
+    @patch("tpu_inference.models.jax.deepseek_v3.DeepSeekV3WeightLoader")
+    def test_load_weights_called(self, mock_loader_cls, mock_config, rng,
+                                 mesh):
+        model = DeepSeekV3(mock_config, rng, mesh)
+        mock_loader_instance = mock_loader_cls.return_value
+
+        model.load_weights(rng)
+
+        mock_loader_instance.load_weights.assert_called_once_with(model)
+
+
+class TestDeepSeekV3WeightLoader:
+
+    @pytest.fixture
+    def loader(self, mock_config):
+        # We need to mock the generator so it doesn't try to download files
+        with patch(
+                "tpu_inference.models.jax.deepseek_v3.model_weights_generator",
+                return_value=[]):
+            return DeepSeekV3WeightLoader(vllm_config=mock_config,
+                                          num_layers=2,
+                                          hidden_size=7168,
+                                          q_lora_rank=1536,
+                                          kv_lora_rank=512,
+                                          attn_heads=128,
+                                          qk_nope_head_dim=128,
+                                          qk_rope_head_dim=64,
+                                          v_head_dim=128,
+                                          num_local_experts=256,
+                                          model_dtype=jnp.bfloat16)
+
+    @pytest.mark.parametrize("loaded_key, expected_mapped", [
+        ("model.embed_tokens.weight", "embedder.input_embedding_table_VD"),
+        ("model.layers.0.self_attn.q_a_proj.weight",
+         "layers.0.attn.kernel_q_down_proj_DA"),
+        ("model.layers.5.mlp.experts.10.gate_proj.weight",
+         "layers.5.custom_module.kernel_gating_EDF"),
+        ("model.layers.1.mlp.shared_experts.down_proj.weight",
+         "layers.1.shared_experts.kernel_down_proj_FD"),
+        ("model.norm.weight", "final_norm.scale"),
+    ])
+    def test_key_mapping(self, loader, loaded_key, expected_mapped):
+        assert loader.map_loaded_to_standardized_name(
+            loaded_key) == expected_mapped
+
+    def test_transpose_params(self, loader):
+        # Test a standard MLP transpose (1, 0)
+        dummy_weight = jnp.ones((100, 200))
+        transposed = loader._transpose_params("mlp.down_proj", dummy_weight)
+        assert transposed.shape == (200, 100)
+
+        # Test MLA kernel transpose (2, 0, 1)
+        dummy_mla = jnp.ones((10, 20, 30))
+        transposed_mla = loader._transpose_params("k_b_proj", dummy_mla)
+        assert transposed_mla.shape == (30, 10, 20)
+
+    def test_moe_stacking_logic(self, loader):
+        """Tests that individual expert weights are collected and stacked correctly."""
+        weights_dict = {}
+        layer_num = "0"
+        loader.num_routed_experts = 4  # Small for test
+
+        # Simulate loading 4 experts
+        for i in range(4):
+            name = f"model.layers.0.mlp.experts.{i}.gate_proj.weight"
+            weight = torch.ones((10, 20)) * i
+            result = loader._process_moe_weights(name, weight, weights_dict)
+
+            if i < 3:
+                assert result is None
+                assert weights_dict[layer_num][1] == i + 1
+            else:
+                # On the last expert, it should return stacked tensor
+                assert result is not None
+                assert result.shape == (4, 10, 20)
+                assert layer_num not in weights_dict  # Should be cleaned up
+
+    def test_mla_kernel_weight_splitting(self, loader, mesh):
+        """Tests that kv_b_proj is split into k_b_proj and v_b_proj for MLA kernel."""
+        loader.use_mla_kernel = True
+        loader.attn_heads = 2
+        loader.qk_nope_head_dim = 4
+        loader.v_head_dim = 4
+        loader.kv_lora_rank = 8
+
+        # Total rows = heads * (nope_dim + v_dim) = 2 * (4 + 4) = 16
+        # Cols = kv_lora_rank = 8
+        kv_b_proj_weight = torch.randn((16, 8))
+
+        # Mocking the load_individual_weight to capture what gets passed
+        with patch.object(loader,
+                          '_load_individual_weight',
+                          return_value=(0, 0)):
+            model_mock = MagicMock()
+            model_mock.mesh = mesh
+
+            # Simulate the splitting logic in the loader
+            weight_reshaped = kv_b_proj_weight.view(2, 4 + 4, 8)
+            k_weight = weight_reshaped[:, :4, :]
+            v_weight = weight_reshaped[:, 4:, :]
+
+            # Verify shapes of split parts
+            assert k_weight.shape == (2, 4, 8)
+            assert v_weight.shape == (2, 4, 8)
+
+    def test_load_individual_weight_with_mxfp4(self, loader, mesh):
+        """Tests the logic for unpacking MXFP4 weights."""
+        name = "layers.0.attn.kernel_q_down_proj_DA"
+        # Mocking torch tensor as uint8 (packed fp4)
+        expected_weight_shape = (128, 128)  # Unpacked
+        expected_scale_shape = (128, 1)
+
+        weight = torch.zeros(expected_weight_shape, dtype=torch.uint8)
+        scale = torch.ones(expected_scale_shape, dtype=torch.float32)
+
+        # Mock model parameters
+        mock_var = MockVariable(
+            (128, 128),
+            dtype=jnp.float4_e2m1fn,
+            sharding=(None, ('attn_dp', 'model',
+                             'expert')))  # Unpacked shape (64 * 2)
+        mock_params = {
+            "layers": {
+                "0": {
+                    "attn": {
+                        "kernel_q_down_proj_DA": mock_var
+                    }
+                }
+            }
+        }
+
+        with patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
+             patch("tpu_inference.models.jax.deepseek_v3.u8_unpack_e2m1") as mock_unpack, \
+             patch("jax.make_array_from_callback") as mock_make_array:
+
+            def side_effect_router(shape, *args, **kwargs):
+                if shape == expected_scale_shape:
+                    # Return FP32 for the scale call
+                    return jnp.ones(shape, dtype=jnp.float32)
+                elif shape == expected_weight_shape:
+                    # Return FP4 for the weight call
+                    return jnp.zeros(shape, dtype=jnp.float4_e2m1fn)
+                return jnp.zeros(shape)  # Fallback
+
+            mock_make_array.side_effect = side_effect_router
+            mock_unpack.return_value = torch.zeros(expected_weight_shape)
+
+            loader._load_individual_weight(name,
+                                           weight,
+                                           mock_params,
+                                           mesh,
+                                           scale=scale)
+
+            mock_unpack.assert_called_once()
+            (actual_arg, ), _ = mock_unpack.call_args
+            # The implementation converts the torch weight to a JAX array
+            expected_arg = jnp.array(weight.cpu().numpy())
+            assert jnp.array_equal(actual_arg, expected_arg).item()
+            assert mock_make_array.called
+
+    def test_load_weights_full_flow(self, loader, mesh):
+        """Integrative test for the load_weights loop."""
+        model = MagicMock(spec=nnx.Module)
+        model.mesh = mesh
+
+        # Setup generator to return one normal weight
+        loader.names_and_weights_generator = [("model.embed_tokens.weight",
+                                               torch.ones((10, 10)))]
+
+        mock_var = MockVariable((10, 10))
+
+        with patch("tpu_inference.models.jax.deepseek_v3.nnx.state"), \
+             patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
+             patch("tpu_inference.models.jax.deepseek_v3.nnx.update"), \
+             patch.object(loader, '_load_individual_weight', return_value=(1.0, 0.5)):
+
+            loader.load_weights(model)
+            # Verify verbose logging worked if enabled
+            assert loader.is_verbose is True
+
+    def test_load_individual_weight_unpacked(self, loader, mesh):
+        """
+        Tests the logic for loading 'unpacked' weights (e.g., standard FP8).
+        This verifies the branch that uses DTYPE_VIEW_MAP for raw memory conversion.
+        """
+        name = "layers.0.attn.kernel_q_down_proj_DA"
+
+        # 1. Setup a standard 'unpacked' FP8 torch tensor
+        # DeepSeek V3 weights are often float8_e4m3fn
+        weight_shape = (128, 128)
+        weight = torch.randn(weight_shape).to(torch.float8_e4m3fn)
+
+        # 2. Mock model parameters to expect jnp.float8_e4m3fn
+        # We reuse the MockVariable helper but specify the dtype
+        mock_var = MockVariable(weight_shape, dtype=jnp.float8_e4m3fn)
+        mock_params = {
+            "layers": {
+                "0": {
+                    "attn": {
+                        "kernel_q_down_proj_DA": mock_var
+                    }
+                }
+            }
+        }
+
+        # 3. Patch the necessary JAX/Utility functions
+        with patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
+             patch("tpu_inference.models.jax.deepseek_v3.u8_unpack_e2m1") as mock_unpack, \
+             patch("jax.make_array_from_callback") as mock_make_array:
+
+            # Mock the JAX array creation to return a dummy
+            mock_make_array.return_value = jnp.zeros(weight_shape,
+                                                     dtype=jnp.float8_e4m3fn)
+
+            # Execute the loader method
+            loader._load_individual_weight(name,
+                                           weight,
+                                           mock_params,
+                                           mesh,
+                                           scale=None)
+
+            # VERIFICATIONS:
+            # - u8_unpack_e2m1 should NOT be called for standard FP8 (only for packed uint8 + scale)
+            mock_unpack.assert_not_called()
+
+            # - make_array_from_callback should be called with the correct shape and sharding
+            # The first argument to make_array_from_callback is the shape
+            assert mock_make_array.call_args[0][0] == weight_shape
+
+            # - Verify the model weight value was updated (even if with our dummy)
+            assert mock_var.value.dtype == jnp.float8_e4m3fn
+
+    def test_load_individual_weight_with_scale(self, loader, mesh):
+        """
+        Tests loading an unpacked weight that also has a quantization scale.
+        """
+        name = "layers.0.custom_module.kernel_gating_DF"
+        weight_shape = (64, 128)
+        scale_shape = (64, 1)
+
+        # Use BF16 for this test to verify DTYPE_VIEW_MAP handles multiple types
+        weight = torch.randn(weight_shape).to(torch.bfloat16)
+        scale = torch.ones(scale_shape, dtype=torch.float32)
+
+        mock_var = MockVariable(weight_shape, dtype=jnp.bfloat16)
+        mock_params = {
+            "layers": {
+                "0": {
+                    "custom_module": {
+                        "kernel_gating_DF": mock_var
+                    }
+                }
+            }
+        }
+
+        with patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
+             patch("jax.make_array_from_callback") as mock_make_array:
+
+            def side_effect_router(shape, *args, **kwargs):
+                if shape == scale_shape:
+                    # Return FP32 for the scale call
+                    return jnp.ones(shape, dtype=jnp.float32)
+                elif shape == weight_shape:
+                    # Return BF16 for the weight call
+                    return jnp.zeros(shape, dtype=jnp.bfloat16)
+                return jnp.zeros(shape)  # Fallback
+
+            mock_make_array.side_effect = side_effect_router
+
+            loader._load_individual_weight(name,
+                                           weight,
+                                           mock_params,
+                                           mesh,
+                                           scale=scale)
+
+            # Verify the scale was applied to the MockVariable's internal QArray structure
+            # (In the model code: base_model_weight.array.scale.value = maybe_sharded_scale)
+            assert mock_var.array.scale.value is not None
tests/models/jax/test_llama3.py

@@ -0,0 +1,184 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from vllm.config import ModelConfig
+
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.models.jax.llama3 import LlamaForCausalLM
+from tpu_inference.runner.kv_cache import create_kv_caches
+
+
+class MockVllmConfig:
+
+    def __init__(self, model: str, kv_cache_dtype: str):
+        self.model_config = ModelConfig(model)
+        self.model_config.dtype = jnp.bfloat16
+        self.load_config = MagicMock()
+        self.load_config.download_dir = None
+        self.speculative_config = None
+        self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
+
+
+@pytest.fixture(scope="module")
+def mesh():
+    """
+    Creates a mesh with 1 device.
+    """
+    if not jax.devices():
+        pytest.skip("No JAX devices available for mesh creation.")
+
+    devices = np.array(jax.local_devices()[:1])
+    num_devices = len(devices)
+    assert num_devices == 1
+    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+
+    with Mesh(device_mesh,
+              axis_names=('data', 'attn_dp', 'expert', 'model')) as m:
+        yield m
+
+
+@pytest.fixture
+def mock_model_inputs():
+    num_tokens = 8
+    num_reqs = 1
+    max_num_blocks_per_req = 4
+    input_ids = jnp.ones((num_tokens, ), dtype=jnp.int32)
+    positions = jnp.ones((num_tokens, ), dtype=jnp.int32)
+    block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req),
+                             dtype=jnp.int32).reshape(-1)
+    seq_lens = jnp.ones((num_reqs, ), dtype=jnp.int32)
+    query_start_loc = jnp.ones((num_reqs + 1, ), dtype=jnp.int32)
+    request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32)
+
+    attention_metadata = AttentionMetadata(
+        input_positions=positions,
+        block_tables=block_tables,
+        seq_lens=seq_lens,
+        query_start_loc=query_start_loc,
+        request_distribution=request_distribution,
+    )
+    indices_do_sample = jnp.ones((num_reqs, ), dtype=jnp.int32)
+
+    return (input_ids, attention_metadata, indices_do_sample)
+
+
+@pytest.fixture
+def rng() -> PRNGKey:
+    """Provides a reusable JAX PRNGKey."""
+    return jax.random.PRNGKey(42)
+
+
+@pytest.fixture(autouse=True)
+def mock_get_pp_group():
+    mock_pp = MagicMock(is_first_rank=True,
+                        is_last_rank=True,
+                        rank_in_group=0,
+                        world_size=1)
+    with patch("tpu_inference.models.jax.llama3.get_pp_group",
+               return_value=mock_pp), patch(
+                   "tpu_inference.layers.jax.pp_utils.get_pp_group",
+                   return_value=mock_pp):
+        yield
+
+
+class TestLlamaForCausalLM:
+    """Tests for the main LlamaForCausalLM model class."""
+
+    @pytest.mark.parametrize("mock_vllm_config", [
+        MockVllmConfig("meta-llama/Llama-3.2-1B", "auto"),
+        MockVllmConfig("meta-llama/Llama-3.2-1B", "fp8")
+    ])
+    def test_llama32_1b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
+ """Tests model init and model forward for the 8B model variant."""
113
+
+        # Test model init
+        model = LlamaForCausalLM(mock_vllm_config, rng, mesh)
+
+        model_config = mock_vllm_config.model_config
+        hf_config = model_config.hf_config
+
+        assert model.mesh.shape == {
+            "data": 1,
+            "attn_dp": 1,
+            "expert": 1,
+            "model": 1
+        }
+
+        layers = model.model.layers
+        assert len(layers) == hf_config.num_hidden_layers
+        assert isinstance(model.rng, nnx.Rngs)
+        assert model.model.lm_head == model.model.embed.embedding
+
+        attn = layers[0].self_attn
+        hidden_size = hf_config.hidden_size
+        num_heads = hf_config.num_attention_heads
+        num_kv_heads = hf_config.num_key_value_heads
+        rope_theta = hf_config.rope_theta
+        head_dim = hf_config.head_dim
+        intermediate_size = hf_config.intermediate_size
+
+        assert attn.hidden_size == hidden_size
+        assert attn.num_heads == num_heads
+        assert attn.num_kv_heads == num_kv_heads
+        assert attn.rope_theta == rope_theta
+        assert attn.head_dim_original == head_dim
+        assert attn.head_dim == head_dim
+        assert attn.q_proj.kernel.shape == (hidden_size, num_heads, head_dim)
+        assert attn.k_proj.kernel.shape == (hidden_size, num_kv_heads,
+                                            head_dim)
+        assert attn.v_proj.kernel.shape == (hidden_size, num_kv_heads,
+                                            head_dim)
+        assert attn.o_proj.kernel.shape == (num_heads, head_dim, hidden_size)
+
+        mlp = layers[0].mlp
+        assert mlp.gate_proj.kernel.shape == (hidden_size, intermediate_size)
+        assert mlp.up_proj.kernel.shape == (hidden_size, intermediate_size)
+        assert mlp.down_proj.kernel.shape == (intermediate_size, hidden_size)
+
+        # Test model load
+        model.load_weights(rng)
+
+        # Test model forward
+        kv_caches = create_kv_caches(
+            num_blocks=4,
+            block_size=32,
+            num_kv_heads=num_kv_heads,
+            head_size=head_dim,
+            mesh=mesh,
+            layer_names=["layer"] * hf_config.num_hidden_layers,
+            cache_dtype=jnp.float8_e4m3fn
+            if mock_vllm_config.cache_config.cache_dtype == "fp8" else
+            jnp.bfloat16)
+        # 1 seq with 8 tokens
+        input_ids, attention_metadata, indices_do_sample = mock_model_inputs
+        kv_caches, hidden_states, aux_hidden_states = model(
+            kv_caches, input_ids, attention_metadata, None, None, None, None,
+            None, True, True)
+        assert hidden_states.shape == (8, hidden_size)
+        assert len(aux_hidden_states) == 0
+
+        hidden_states = hidden_states[indices_do_sample]
+        assert hidden_states.shape == (1, hidden_size)
+
+        logits = model.compute_logits(hidden_states)
+        assert logits.shape == (1, hf_config.vocab_size)